[TOSA] Fix rsub; add clamp.Tensor, avg_pool1d, max_pool1d, prims.collapse (#3855)

- Fix aten.rsub.Scalar legalization with appropriate type casting
- Add legalization for aten.clamp.Tensor
- Resolve some unexpected test failures from PyTorch update by adding
legalization for the following ops:
  + aten.avg_pool1d
  + aten.max_pool1d
  + torch.prims.collapse
- Update xfail_sets with new e2e results
- Add new LIT tests to basic.mlir


Change-Id: I9762c7d36ca0b0f75ca68d0c71d7f5d5309a96ad

---------

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3860/head
Justin Ngo 2024-11-07 14:09:43 -08:00 committed by GitHub
parent 8519ecc4d7
commit b6f04fa32b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 481 additions and 55 deletions

View File

@ -2072,25 +2072,29 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub");
auto resultTy =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
auto resultElemTy = resultTy.getElementType();
self = tosa::promoteType(rewriter, self, resultTy);
Value otherTensor, alphaTensor;
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
selfTy.getElementType(), {})))
resultElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Rsub operation");
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
alphaTensor, selfTy.getElementType(),
alphaTensor, resultElemTy,
/*checkForUnity=*/true)))
return failure();
auto multTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
alphaTensor, /*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(
op, getTypeConverter()->convertType(op.getType()), otherTensor,
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
multTensor);
return success();
@ -4730,6 +4734,108 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return success();
}
// Legalization for aten.clamp.Tensor
template <>
LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
AtenClampTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// We are not using tosa.clamp to lower aten.clamp.Tensor, as
// aten.clamp.Tensor's min and max attributes are tensors that can have size
// greater than 1, which is not compatible with tosa.clamp.
//
// Instead, we use the following formula:
// yi = min(max(xi, min_valuei), max_valuei)
auto self = adaptor.getSelf();
// Not a tensor type
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfElemTy = selfType.getElementType();
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
// Get min tensor. If None, there is no lower bound.
Value min;
if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) {
min = adaptor.getMin();
} else {
min =
TypeSwitch<Type, Value>(selfElemTy)
.Case<mlir::FloatType>([&](auto) {
return tosa::getConstTensor<float>(
rewriter, op, std::numeric_limits<float>::lowest(), {},
selfElemTy)
.value();
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 8:
return tosa::getConstTensor<int8_t>(
rewriter, op, std::numeric_limits<int8_t>::min(), {})
.value();
case 32:
return tosa::getConstTensor<int32_t>(
rewriter, op, std::numeric_limits<int32_t>::min(),
{})
.value();
case 64:
return tosa::getConstTensor<int64_t>(
rewriter, op, std::numeric_limits<int64_t>::min(),
{})
.value();
}
llvm_unreachable("Invalid integer width");
});
}
// Get max tensor. If None, there is no upper bound.
Value max;
if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) {
max = adaptor.getMax();
} else {
max =
TypeSwitch<Type, Value>(selfElemTy)
.Case<mlir::FloatType>([&](auto) {
return tosa::getConstTensor<float>(
rewriter, op, std::numeric_limits<float>::max(), {},
selfElemTy)
.value();
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 8:
return tosa::getConstTensor<int8_t>(
rewriter, op, std::numeric_limits<int8_t>::max(), {})
.value();
case 32:
return tosa::getConstTensor<int32_t>(
rewriter, op, std::numeric_limits<int32_t>::max(),
{})
.value();
case 64:
return tosa::getConstTensor<int64_t>(
rewriter, op, std::numeric_limits<int64_t>::max(),
{})
.value();
}
llvm_unreachable("Invalid integer width");
});
}
// max(xi, min_valuei)
auto minThresholdCheck = tosa::createBinaryOpAndCast<tosa::MaximumOp>(
rewriter, op, resultType, self, min);
// yi = min(max(xi, min_valuei), max_valuei)
auto result = tosa::createBinaryOpAndCast<tosa::MinimumOp>(
rewriter, op, resultType, minThresholdCheck, max);
rewriter.replaceOp(op, result);
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
AtenArangeStartStepOp op, OpAdaptor adaptor,
@ -5236,11 +5342,29 @@ public:
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
op, rewriter, pooledOutput);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
Value result = transposedOutput;
auto resultTy = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
transposedOutput);
op.getType()));
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
auto resultShape = resultTy.getShape();
auto resultElemTy = resultTy.getElementType();
result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(resultShape),
resultElemTy),
transposedOutput,
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, resultTy,
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
// op.getType()),
result);
return success();
}
@ -5387,6 +5511,12 @@ static LogicalResult getOutputTypeAndPoolingParameters(
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for pooling op unsupported");
// Expand kernel size parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
kernelSizeInts.push_back(1);
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for pooling op unsupported");
@ -5394,13 +5524,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
// list during import. For such a case, the stride value is the kernel size.
// See:
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
if (strideInts.empty())
if (strideInts.empty()) {
strideInts.assign(kernelSizeInts);
} else {
// Expand stride parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
strideInts.push_back(1);
}
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported");
// Expand padding parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
paddingInts[1], paddingInts[1]};
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
@ -5456,6 +5599,68 @@ public:
}
};
// Legalization for aten.max_pool1d
class ConvertAtenMaxPool1dOp
: public ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenMaxPool1dOp,
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
DenseI64ArrayAttr &kernel,
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
auto self = adaptor.getSelf();
// Not a RankedTensorType
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type inputs are supported");
auto selfShape = selfTy.getShape();
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for MaxPool1d should have rank 3");
// Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
rank4Shape.push_back(1);
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
SmallVector<int64_t> dilationArray;
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationArray)))
return rewriter.notifyMatchFailure(
op, "Non-const dilation for pooling op unsupported.");
// TOSA pooling only supports unit dilation.
if (dilationArray[0] > 1)
return rewriter.notifyMatchFailure(
op, "Cannot process non-unit pooling dilation.");
// Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp
dilationArray.push_back(1);
if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
tosa::MaxPool2dOp>(
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
kernel, stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");
// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
return success();
}
};
class ConvertAtenAvgPool2dOp
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
public:
@ -5504,6 +5709,68 @@ public:
}
};
// Legalization for aten.avg_pool1d
class ConvertAtenAvgPool1dOp
: public ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenAvgPool1dOp,
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
DenseI64ArrayAttr &kernel,
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
auto self = adaptor.getSelf();
// Not a RankedTensorType
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type inputs are supported");
auto selfShape = selfTy.getShape();
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for MaxPool1d should have rank 3");
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
rank4Shape.push_back(1);
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}
SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
kernel, stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");
// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
return success();
}
};
// Ref: Error checking based on the Torch to LinAlg lowering
template <typename AtenOpT, int fillVal>
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
@ -6880,6 +7147,49 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
return success();
}
// Legalization for torch.prims.collapse
template <>
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
PrimsCollapseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getA();
// Not a tensor type
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultShape = resultType.getShape();
int64_t start, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "Only constant int start value is supported");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(
op, "Only constant int end value is supported");
// Identity case
if (start == end) {
rewriter.replaceOp(op, self);
return success();
}
// Technically, I should calculate the output shape based on the input shape,
// start value, and end value. However, that would just give the same result
// as me taking the result shape straight from resultType and applying
// tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach
// here, which is more simple and quicker.
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, self,
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
return success();
}
} // namespace
// -----------------------------------------------------------------------------
@ -7101,9 +7411,15 @@ public:
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool1dOp>();
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp>();
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
@ -7199,6 +7515,8 @@ public:
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -1744,6 +1744,23 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic",
"CollapseAllDimensionsModule_basic",
"CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseFracModule_basic",
"ElementwiseLdexpModule_basic",
"ElementwiseSignbitIntModule_basic",
"Exp2StaticIntModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"RsubIntModule_noalpha_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
@ -3373,9 +3390,10 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
@ -3519,11 +3537,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"BoolIntTrueModule_basic",
"BroadcastDynamicDimModule_basic",
"CeilFloatModule_basic",
"CollapseAllDimensionsModule_basic",
"CollapseFullDynamicModule_basic",
"CollapsePartialDynamicModule_basic",
"CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic",
"ConstantBoolParameterModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
@ -3585,10 +3598,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
@ -3784,7 +3793,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"RollModule_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
@ -3897,16 +3905,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IouOfModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"OneHotModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
@ -3927,6 +3931,16 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"ElementwiseCopysignModule_basic",
"ElementwiseFracModule_basic",
"ElementwiseLdexpModule_basic",
"ElementwiseSignbitIntModule_basic",
"ElementwiseSignbitModule_basic",
"Exp2StaticIntModule_basic",
"NllLossStaticModule_basic",
"NllLossStaticModule_mean_basic",
"NllLossStaticModule_sum_basic",
"NllLossStaticModule_weight_basic",
"Exp2StaticModule_basic",
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
@ -3950,7 +3964,6 @@ ONNX_TOSA_XFAIL_SET = {
"TriuIndicesAllZerosModule_basic",
"ElementwiseCreateComplexModule_basic",
"ReduceAllDimFloatModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
@ -4029,7 +4042,6 @@ ONNX_TOSA_XFAIL_SET = {
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
@ -4285,10 +4297,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
@ -4335,7 +4343,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
@ -4414,8 +4421,6 @@ ONNX_TOSA_XFAIL_SET = {
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HBC_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardtanhBackward_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
@ -4463,7 +4468,6 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
@ -4474,8 +4478,6 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateDynamicModule_scales_recompute_bilinear",
@ -4503,10 +4505,7 @@ ONNX_TOSA_XFAIL_SET = {
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
@ -4607,7 +4606,6 @@ ONNX_TOSA_XFAIL_SET = {
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalFunctionalModule_basic",
"NormalizeModule_basic",
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"NumelModule_basic",
@ -4730,7 +4728,6 @@ ONNX_TOSA_XFAIL_SET = {
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeCollapseModule_basic",

View File

@ -2258,16 +2258,8 @@ func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1:
// -----
// CHECK-LABEL: func.func @torch.aten.uniform$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
// CHECK: }
// CHECK-LABEL: torch.aten.uniform$basic
// CHECK: tosa.const
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
%float1.000000e00 = torch.constant.float 1.000000e+00
%float1.000000e01 = torch.constant.float 1.000000e+01
@ -2313,3 +2305,122 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
return %2 : !torch.vtensor<[3,3],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 64, 112, 1>} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32>
// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array<i64: 3, 1>, pad = array<i64: 1, 0, 0, 0>, stride = array<i64: 2, 1>} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32>
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 64, 56>} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32>
// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32>
// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32>
// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32>
// CHECK: }
func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56],f32>
return %4 : !torch.vtensor<[1,64,56],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 512, 10, 1>} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array<i64: 1, 512, 10>} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32>
// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32>
// CHECK: }
func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
// CHECK: %[[VAL_6:.*]] = torch.constant.none
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
// CHECK: }
func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
%none = torch.constant.none
%0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32>
%1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
%2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
}
// -----
// CHECK-LABEL: func.func @torch.prims.collapse$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 2, 12>} : (tensor<2x3x4xf32>) -> tensor<2x12xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32>
// CHECK: }
func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}