mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix rsub; add clamp.Tensor, avg_pool1d, max_pool1d, prims.collapse (#3855)
- Fix aten.rsub.Scalar legalization with appropriate type casting - Add legalization for aten.clamp.Tensor - Resolve some unexpected test failures from PyTorch update by adding legalization for the following ops: + aten.avg_pool1d + aten.max_pool1d + torch.prims.collapse - Update xfail_sets with new e2e results - Add new LIT tests to basic.mlir Change-Id: I9762c7d36ca0b0f75ca68d0c71d7f5d5309a96ad --------- Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3860/head
parent
8519ecc4d7
commit
b6f04fa32b
|
@ -2072,25 +2072,29 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||||
|
|
||||||
|
auto resultTy =
|
||||||
|
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
auto resultElemTy = resultTy.getElementType();
|
||||||
|
|
||||||
|
self = tosa::promoteType(rewriter, self, resultTy);
|
||||||
|
|
||||||
Value otherTensor, alphaTensor;
|
Value otherTensor, alphaTensor;
|
||||||
|
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||||
selfTy.getElementType(), {})))
|
resultElemTy, {})))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Currently only scalar constants are supported for "
|
op, "Currently only scalar constants are supported for "
|
||||||
"conversion in TOSA Rsub operation");
|
"conversion in TOSA Rsub operation");
|
||||||
|
|
||||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
||||||
alphaTensor, selfTy.getElementType(),
|
alphaTensor, resultElemTy,
|
||||||
/*checkForUnity=*/true)))
|
/*checkForUnity=*/true)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
|
||||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
|
||||||
alphaTensor, /*shift=*/0);
|
alphaTensor, /*shift=*/0);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
|
||||||
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
|
||||||
multTensor);
|
multTensor);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -4730,6 +4734,108 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.clamp.Tensor
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
|
||||||
|
AtenClampTensorOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// We are not using tosa.clamp to lower aten.clamp.Tensor, as
|
||||||
|
// aten.clamp.Tensor's min and max attributes are tensors that can have size
|
||||||
|
// greater than 1, which is not compatible with tosa.clamp.
|
||||||
|
//
|
||||||
|
// Instead, we use the following formula:
|
||||||
|
// yi = min(max(xi, min_valuei), max_valuei)
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a tensor type
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
|
// Get min tensor. If None, there is no lower bound.
|
||||||
|
Value min;
|
||||||
|
if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) {
|
||||||
|
min = adaptor.getMin();
|
||||||
|
} else {
|
||||||
|
min =
|
||||||
|
TypeSwitch<Type, Value>(selfElemTy)
|
||||||
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return tosa::getConstTensor<float>(
|
||||||
|
rewriter, op, std::numeric_limits<float>::lowest(), {},
|
||||||
|
selfElemTy)
|
||||||
|
.value();
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 8:
|
||||||
|
return tosa::getConstTensor<int8_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int8_t>::min(), {})
|
||||||
|
.value();
|
||||||
|
case 32:
|
||||||
|
return tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int32_t>::min(),
|
||||||
|
{})
|
||||||
|
.value();
|
||||||
|
case 64:
|
||||||
|
return tosa::getConstTensor<int64_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int64_t>::min(),
|
||||||
|
{})
|
||||||
|
.value();
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get max tensor. If None, there is no upper bound.
|
||||||
|
Value max;
|
||||||
|
if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) {
|
||||||
|
max = adaptor.getMax();
|
||||||
|
} else {
|
||||||
|
max =
|
||||||
|
TypeSwitch<Type, Value>(selfElemTy)
|
||||||
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return tosa::getConstTensor<float>(
|
||||||
|
rewriter, op, std::numeric_limits<float>::max(), {},
|
||||||
|
selfElemTy)
|
||||||
|
.value();
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 8:
|
||||||
|
return tosa::getConstTensor<int8_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int8_t>::max(), {})
|
||||||
|
.value();
|
||||||
|
case 32:
|
||||||
|
return tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int32_t>::max(),
|
||||||
|
{})
|
||||||
|
.value();
|
||||||
|
case 64:
|
||||||
|
return tosa::getConstTensor<int64_t>(
|
||||||
|
rewriter, op, std::numeric_limits<int64_t>::max(),
|
||||||
|
{})
|
||||||
|
.value();
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// max(xi, min_valuei)
|
||||||
|
auto minThresholdCheck = tosa::createBinaryOpAndCast<tosa::MaximumOp>(
|
||||||
|
rewriter, op, resultType, self, min);
|
||||||
|
|
||||||
|
// yi = min(max(xi, min_valuei), max_valuei)
|
||||||
|
auto result = tosa::createBinaryOpAndCast<tosa::MinimumOp>(
|
||||||
|
rewriter, op, resultType, minThresholdCheck, max);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
||||||
AtenArangeStartStepOp op, OpAdaptor adaptor,
|
AtenArangeStartStepOp op, OpAdaptor adaptor,
|
||||||
|
@ -5236,11 +5342,29 @@ public:
|
||||||
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
|
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
|
||||||
op, rewriter, pooledOutput);
|
op, rewriter, pooledOutput);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
Value result = transposedOutput;
|
||||||
op,
|
auto resultTy = dyn_cast<TensorType>(
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()));
|
||||||
transposedOutput);
|
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
|
||||||
|
auto resultShape = resultTy.getShape();
|
||||||
|
auto resultElemTy = resultTy.getElementType();
|
||||||
|
|
||||||
|
result = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(resultShape),
|
||||||
|
resultElemTy),
|
||||||
|
transposedOutput,
|
||||||
|
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||||
|
op, resultTy,
|
||||||
|
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
// op.getType()),
|
||||||
|
result);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -5387,6 +5511,12 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Non-const kernel_size for pooling op unsupported");
|
op, "Non-const kernel_size for pooling op unsupported");
|
||||||
|
|
||||||
|
// Expand kernel size parameter to size 2 to be compatible with
|
||||||
|
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||||
|
kernelSizeInts.push_back(1);
|
||||||
|
|
||||||
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Non-const stride for pooling op unsupported");
|
op, "Non-const stride for pooling op unsupported");
|
||||||
|
@ -5394,13 +5524,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
||||||
// list during import. For such a case, the stride value is the kernel size.
|
// list during import. For such a case, the stride value is the kernel size.
|
||||||
// See:
|
// See:
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||||
if (strideInts.empty())
|
if (strideInts.empty()) {
|
||||||
strideInts.assign(kernelSizeInts);
|
strideInts.assign(kernelSizeInts);
|
||||||
|
} else {
|
||||||
|
// Expand stride parameter to size 2 to be compatible with
|
||||||
|
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||||
|
strideInts.push_back(1);
|
||||||
|
}
|
||||||
|
|
||||||
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Non-const padding factor for pooling op unsupported");
|
op, "Non-const padding factor for pooling op unsupported");
|
||||||
|
|
||||||
|
// Expand padding parameter to size 2 to be compatible with
|
||||||
|
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||||
|
paddingInts.push_back(0);
|
||||||
|
|
||||||
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
|
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
|
||||||
paddingInts[1], paddingInts[1]};
|
paddingInts[1], paddingInts[1]};
|
||||||
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
|
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
|
||||||
|
@ -5456,6 +5599,68 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Legalization for aten.max_pool1d
|
||||||
|
class ConvertAtenMaxPool1dOp
|
||||||
|
: public ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp> {
|
||||||
|
public:
|
||||||
|
using ConvertAtenPoolingBaseOp<AtenMaxPool1dOp,
|
||||||
|
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
|
||||||
|
LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter, Value &input,
|
||||||
|
DenseI64ArrayAttr &kernel,
|
||||||
|
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||||
|
Type &outputTy) const override {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a RankedTensorType
|
||||||
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor type inputs are supported");
|
||||||
|
auto selfShape = selfTy.getShape();
|
||||||
|
|
||||||
|
// Expected a rank 3 input tensor
|
||||||
|
if (selfTy.getRank() != 3)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Input tensor for MaxPool1d should have rank 3");
|
||||||
|
|
||||||
|
// Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
|
||||||
|
SmallVector<int64_t> rank4Shape(selfShape);
|
||||||
|
rank4Shape.push_back(1);
|
||||||
|
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
|
||||||
|
selfTy.getElementType()),
|
||||||
|
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
|
||||||
|
|
||||||
|
SmallVector<int64_t> dilationArray;
|
||||||
|
if (!matchPattern(op.getDilation(),
|
||||||
|
m_TorchListOfConstantInts(dilationArray)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Non-const dilation for pooling op unsupported.");
|
||||||
|
// TOSA pooling only supports unit dilation.
|
||||||
|
if (dilationArray[0] > 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Cannot process non-unit pooling dilation.");
|
||||||
|
|
||||||
|
// Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp
|
||||||
|
dilationArray.push_back(1);
|
||||||
|
|
||||||
|
if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
|
||||||
|
tosa::MaxPool2dOp>(
|
||||||
|
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
|
||||||
|
kernel, stride, pad)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "invalid pooling parameters or input type");
|
||||||
|
|
||||||
|
// Transpose to xHWC
|
||||||
|
input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
|
||||||
|
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ConvertAtenAvgPool2dOp
|
class ConvertAtenAvgPool2dOp
|
||||||
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
|
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -5504,6 +5709,68 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Legalization for aten.avg_pool1d
|
||||||
|
class ConvertAtenAvgPool1dOp
|
||||||
|
: public ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp> {
|
||||||
|
public:
|
||||||
|
using ConvertAtenPoolingBaseOp<AtenAvgPool1dOp,
|
||||||
|
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
|
||||||
|
LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter, Value &input,
|
||||||
|
DenseI64ArrayAttr &kernel,
|
||||||
|
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||||
|
Type &outputTy) const override {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a RankedTensorType
|
||||||
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor type inputs are supported");
|
||||||
|
auto selfShape = selfTy.getShape();
|
||||||
|
|
||||||
|
// Expected a rank 3 input tensor
|
||||||
|
if (selfTy.getRank() != 3)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Input tensor for MaxPool1d should have rank 3");
|
||||||
|
|
||||||
|
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
|
||||||
|
SmallVector<int64_t> rank4Shape(selfShape);
|
||||||
|
rank4Shape.push_back(1);
|
||||||
|
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
|
||||||
|
selfTy.getElementType()),
|
||||||
|
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
|
||||||
|
|
||||||
|
// Currently, we can not represent `count_include_pad` with the existing
|
||||||
|
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||||
|
// wrong answers (SWA) when the `count_include_pad` value is `true.`
|
||||||
|
bool countIncludePad;
|
||||||
|
if (!matchPattern(op.getCountIncludePad(),
|
||||||
|
m_TorchConstantBool(&countIncludePad)) ||
|
||||||
|
countIncludePad) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
|
||||||
|
"`count_include_pad` value should be `False`.");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> dilationArray{1, 1};
|
||||||
|
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
|
||||||
|
tosa::AvgPool2dOp>(
|
||||||
|
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
|
||||||
|
kernel, stride, pad)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "invalid pooling parameters or input type");
|
||||||
|
|
||||||
|
// Transpose to xHWC
|
||||||
|
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
|
||||||
|
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Ref: Error checking based on the Torch to LinAlg lowering
|
// Ref: Error checking based on the Torch to LinAlg lowering
|
||||||
template <typename AtenOpT, int fillVal>
|
template <typename AtenOpT, int fillVal>
|
||||||
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
@ -6880,6 +7147,49 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for torch.prims.collapse
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||||
|
PrimsCollapseOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getA();
|
||||||
|
|
||||||
|
// Not a tensor type
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
auto resultShape = resultType.getShape();
|
||||||
|
|
||||||
|
int64_t start, end;
|
||||||
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only constant int start value is supported");
|
||||||
|
|
||||||
|
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only constant int end value is supported");
|
||||||
|
|
||||||
|
// Identity case
|
||||||
|
if (start == end) {
|
||||||
|
rewriter.replaceOp(op, self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Technically, I should calculate the output shape based on the input shape,
|
||||||
|
// start value, and end value. However, that would just give the same result
|
||||||
|
// as me taking the result shape straight from resultType and applying
|
||||||
|
// tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach
|
||||||
|
// here, which is more simple and quicker.
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||||
|
op, resultType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -7101,9 +7411,15 @@ public:
|
||||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||||
|
|
||||||
|
target.addIllegalOp<AtenMaxPool1dOp>();
|
||||||
|
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
|
||||||
|
|
||||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||||
|
|
||||||
|
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||||
|
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
|
||||||
|
|
||||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||||
|
@ -7199,6 +7515,8 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
||||||
|
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -1744,6 +1744,23 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
|
"CollapseAllDimensionsModule_basic",
|
||||||
|
"CollapseRank1DynamicModule_basic",
|
||||||
|
"CollapseStaticModule_basic",
|
||||||
|
"ElementwiseClampMinTensorFloatModule_basic",
|
||||||
|
"ElementwiseClampMinTensorIntModule_basic",
|
||||||
|
"ElementwiseClampTensorFloatModule_basic",
|
||||||
|
"ElementwiseClampTensorIntModule_basic",
|
||||||
|
"ElementwiseFracModule_basic",
|
||||||
|
"ElementwiseLdexpModule_basic",
|
||||||
|
"ElementwiseSignbitIntModule_basic",
|
||||||
|
"Exp2StaticIntModule_basic",
|
||||||
|
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||||
|
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||||
|
"MaxPool1dStaticModule_basic",
|
||||||
|
"RepeatInterleaveSelfIntModule_basic",
|
||||||
|
"RsubIntModule_noalpha_basic",
|
||||||
"Aten_TrilinearModuleSumAllDims_basic",
|
"Aten_TrilinearModuleSumAllDims_basic",
|
||||||
"Aten_TrilinearModuleSumdims_basic",
|
"Aten_TrilinearModuleSumdims_basic",
|
||||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||||
|
@ -3373,9 +3390,10 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"ElementwiseCopysignModule_basic",
|
||||||
|
"ElementwiseSignbitModule_basic",
|
||||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
|
||||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||||
"MaxPool3dEmptyStrideStaticModule_basic",
|
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||||
|
@ -3519,11 +3537,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"BoolIntTrueModule_basic",
|
"BoolIntTrueModule_basic",
|
||||||
"BroadcastDynamicDimModule_basic",
|
"BroadcastDynamicDimModule_basic",
|
||||||
"CeilFloatModule_basic",
|
"CeilFloatModule_basic",
|
||||||
"CollapseAllDimensionsModule_basic",
|
|
||||||
"CollapseFullDynamicModule_basic",
|
|
||||||
"CollapsePartialDynamicModule_basic",
|
|
||||||
"CollapseRank1DynamicModule_basic",
|
|
||||||
"CollapseStaticModule_basic",
|
|
||||||
"ConstantBoolParameterModule_basic",
|
"ConstantBoolParameterModule_basic",
|
||||||
"ContainsIntList_False",
|
"ContainsIntList_False",
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
|
@ -3585,10 +3598,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAtanhIntModule_basic",
|
"ElementwiseAtanhIntModule_basic",
|
||||||
"ElementwiseAtanhModule_basic",
|
"ElementwiseAtanhModule_basic",
|
||||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||||
"ElementwiseClampMinTensorFloatModule_basic",
|
|
||||||
"ElementwiseClampMinTensorIntModule_basic",
|
|
||||||
"ElementwiseClampTensorFloatModule_basic",
|
|
||||||
"ElementwiseClampTensorIntModule_basic",
|
|
||||||
"ElementwiseCosIntModule_basic",
|
"ElementwiseCosIntModule_basic",
|
||||||
"ElementwiseCoshIntModule_basic",
|
"ElementwiseCoshIntModule_basic",
|
||||||
"ElementwiseCoshModule_basic",
|
"ElementwiseCoshModule_basic",
|
||||||
|
@ -3784,7 +3793,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_right0",
|
"ReplicationPad2dModule_right0",
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"RollModule_basic",
|
"RollModule_basic",
|
||||||
"RsubIntModule_noalpha_basic",
|
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
|
@ -3897,16 +3905,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||||
"IouOfModule_basic",
|
"IouOfModule_basic",
|
||||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
|
||||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
|
||||||
"MaxPool1dStaticModule_basic",
|
|
||||||
"MeshgridIndexingIJ_basic",
|
"MeshgridIndexingIJ_basic",
|
||||||
"MeshgridIndexingXY_basic",
|
"MeshgridIndexingXY_basic",
|
||||||
"Meshgrid_basic",
|
"Meshgrid_basic",
|
||||||
"OneHotModule_basic",
|
"OneHotModule_basic",
|
||||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||||
"ReduceFrobeniusNormModule_basic",
|
"ReduceFrobeniusNormModule_basic",
|
||||||
"RepeatInterleaveSelfIntModule_basic",
|
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
|
@ -3927,6 +3931,16 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"ElementwiseCopysignModule_basic",
|
||||||
|
"ElementwiseFracModule_basic",
|
||||||
|
"ElementwiseLdexpModule_basic",
|
||||||
|
"ElementwiseSignbitIntModule_basic",
|
||||||
|
"ElementwiseSignbitModule_basic",
|
||||||
|
"Exp2StaticIntModule_basic",
|
||||||
|
"NllLossStaticModule_basic",
|
||||||
|
"NllLossStaticModule_mean_basic",
|
||||||
|
"NllLossStaticModule_sum_basic",
|
||||||
|
"NllLossStaticModule_weight_basic",
|
||||||
"Exp2StaticModule_basic",
|
"Exp2StaticModule_basic",
|
||||||
"ElementwiseRreluWithNoiseEvalModule_basic",
|
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||||
|
@ -3950,7 +3964,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"TriuIndicesAllZerosModule_basic",
|
"TriuIndicesAllZerosModule_basic",
|
||||||
"ElementwiseCreateComplexModule_basic",
|
"ElementwiseCreateComplexModule_basic",
|
||||||
"ReduceAllDimFloatModule_basic",
|
"ReduceAllDimFloatModule_basic",
|
||||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"HstackBasicComplexModule_basic",
|
"HstackBasicComplexModule_basic",
|
||||||
"HstackBasicFloatModule_basic",
|
"HstackBasicFloatModule_basic",
|
||||||
|
@ -4029,7 +4042,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||||
"AdaptiveAvgPool2dDynamic_basic",
|
"AdaptiveAvgPool2dDynamic_basic",
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
@ -4285,10 +4297,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||||
"ElementwiseBitwiseXorModule_basic",
|
"ElementwiseBitwiseXorModule_basic",
|
||||||
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||||
"ElementwiseClampMaxModule_basic",
|
|
||||||
"ElementwiseClampMinModule_basic",
|
|
||||||
"ElementwiseClampModule_basic",
|
|
||||||
"ElementwiseClampTensorInt8Module_basic",
|
|
||||||
"ElementwiseCosIntModule_basic",
|
"ElementwiseCosIntModule_basic",
|
||||||
"ElementwiseCoshIntModule_basic",
|
"ElementwiseCoshIntModule_basic",
|
||||||
"ElementwiseCoshModule_basic",
|
"ElementwiseCoshModule_basic",
|
||||||
|
@ -4335,7 +4343,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
"ElementwiseReciprocalIntModule_basic",
|
"ElementwiseReciprocalIntModule_basic",
|
||||||
"ElementwiseRelu6Module_basic",
|
|
||||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
|
@ -4414,8 +4421,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"GtFloatIntModule_basic",
|
"GtFloatIntModule_basic",
|
||||||
"GtIntModule_basic",
|
"GtIntModule_basic",
|
||||||
"HBC_basic",
|
"HBC_basic",
|
||||||
"HardTanhIntModule_basic",
|
|
||||||
"HardTanhModule_basic",
|
|
||||||
"HardtanhBackward_basic",
|
"HardtanhBackward_basic",
|
||||||
"IndexPut1DFloatAccumulateModule_basic",
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||||
|
@ -4463,7 +4468,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
"IndexTensorModule3dInput_basic",
|
"IndexTensorModule3dInput_basic",
|
||||||
"IndexTensorModule_basic",
|
"IndexTensorModule_basic",
|
||||||
"IndexTensorMultiIndexStaticModule_basic",
|
|
||||||
"IndexTensorMultiInputContiguousCenter_basic",
|
"IndexTensorMultiInputContiguousCenter_basic",
|
||||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||||
|
@ -4474,8 +4478,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"IndexTensorMultiInputThreeIndexers_basic",
|
"IndexTensorMultiInputThreeIndexers_basic",
|
||||||
"IndexTensorMultiInput_basic",
|
"IndexTensorMultiInput_basic",
|
||||||
"IndexTensorSelectDimModule_basic",
|
"IndexTensorSelectDimModule_basic",
|
||||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_nearest",
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||||
|
@ -4503,10 +4505,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"Matmul_matvec",
|
"Matmul_matvec",
|
||||||
"Matmul_vecmat",
|
"Matmul_vecmat",
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
|
||||||
"MaxPool1dStaticModule_basic",
|
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
"MaxPool2dModule_basic",
|
"MaxPool2dModule_basic",
|
||||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
|
@ -4607,7 +4606,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"NormScalarOptDimKeepDimModule_basic",
|
"NormScalarOptDimKeepDimModule_basic",
|
||||||
"NormScalarOptDimModule_basic",
|
"NormScalarOptDimModule_basic",
|
||||||
"NormalFunctionalModule_basic",
|
"NormalFunctionalModule_basic",
|
||||||
"NormalizeModule_basic",
|
|
||||||
"NumToTensorFloatModule_basic",
|
"NumToTensorFloatModule_basic",
|
||||||
"NumToTensorIntModule_basic",
|
"NumToTensorIntModule_basic",
|
||||||
"NumelModule_basic",
|
"NumelModule_basic",
|
||||||
|
@ -4730,7 +4728,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_right0",
|
"ReplicationPad2dModule_right0",
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"ResNet18Module_basic",
|
"ResNet18Module_basic",
|
||||||
"ResNet18StaticModule_basic",
|
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeCollapseModule_basic",
|
"ReshapeCollapseModule_basic",
|
||||||
|
|
|
@ -2258,16 +2258,8 @@ func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1:
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.uniform$basic(
|
// CHECK-LABEL: torch.aten.uniform$basic
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
// CHECK: tosa.const
|
||||||
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
|
|
||||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
|
||||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
|
|
||||||
// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
|
||||||
// CHECK: }
|
|
||||||
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
||||||
%float1.000000e00 = torch.constant.float 1.000000e+00
|
%float1.000000e00 = torch.constant.float 1.000000e+00
|
||||||
%float1.000000e01 = torch.constant.float 1.000000e+01
|
%float1.000000e01 = torch.constant.float 1.000000e+01
|
||||||
|
@ -2313,3 +2305,122 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
|
||||||
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
|
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
|
||||||
return %2 : !torch.vtensor<[3,3],f32>
|
return %2 : !torch.vtensor<[3,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 64, 112, 1>} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array<i64: 3, 1>, pad = array<i64: 1, 0, 0, 0>, stride = array<i64: 2, 1>} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 64, 56>} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32>
|
||||||
|
// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56],f32>
|
||||||
|
return %4 : !torch.vtensor<[1,64,56],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 512, 10, 1>} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array<i64: 1, 512, 10>} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
|
||||||
|
// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
|
||||||
|
return %3 : !torch.vtensor<[1,512,10],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||||
|
// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32>
|
||||||
|
%1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
|
||||||
|
%2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
|
||||||
|
return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.prims.collapse$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 2, 12>} : (tensor<2x3x4xf32>) -> tensor<2x12xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,12],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue