[TOSA] Add legalization for fill, flip, and round (#3768)

- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and
aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch scalar
input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops
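
For reference, a minimal PyTorch-level sketch of the semantics these lowerings
target (shapes and values are illustrative only, not taken from the e2e suite):

    import torch

    # aten.fill.Scalar / aten.fill.Tensor: every element takes the fill value
    x = torch.empty(2, 3)
    print(x.fill_(0.0))                # 2x3 tensor of zeros

    # aten.flip: reverse the tensor along each listed dim
    y = torch.arange(6).reshape(2, 3)
    print(torch.flip(y, dims=[0, 1]))  # [[5, 4, 3], [2, 1, 0]]

    # aten.round: ties go to the nearest even integer ("round half to even")
    z = torch.tensor([0.5, 1.5, 2.5, 2.3, 2.7])
    print(torch.round(z))              # tensor([0., 2., 2., 2., 3.])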


Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Justin Ngo <justin.ngo@arm.com>  2024-10-07 10:28:26 -07:00 (committed by GitHub)
commit b08d08682f, parent f4840ed886
3 changed files with 299 additions and 57 deletions


@@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(op,
                                        "Unable to extract the scalar constant");
 
+  int64_t numElem = 1;
+  for (int64_t dim : dshape)
+    numElem *= dim;
+
   if (isa<mlir::FloatType>(dtype)) {
-    tosaTensor = tosa::getConstTensor<float>(rewriter, op,
-                                             (isFloat ? doubleValue : intValue),
-                                             dshape, dtype)
-                     .value();
+    tosaTensor =
+        tosa::getConstTensor<float>(
+            rewriter, op,
+            SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
+            dshape, dtype)
+            .value();
   } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
     auto w = intType.getWidth();
     if (w != 1 && w != 32 && w != 64)
@@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       }
       bool d = isFloat ? static_cast<bool>(doubleValue)
                        : static_cast<bool>(intValue);
-      tosaTensor =
-          tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
+      tosaTensor = tosa::getConstTensor<bool>(
+                       rewriter, op, SmallVector<bool>(numElem, d), dshape)
+                       .value();
     } else if (w == 32) {
       if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
@@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       }
       int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
                           : static_cast<int32_t>(intValue);
-      tosaTensor =
-          tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
+      tosaTensor = tosa::getConstTensor<int32_t>(
+                       rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
+                       .value();
     } else if (w == 64) {
       if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
@@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
             "of destination type");
       }
       int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
-      tosaTensor =
-          tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
+      tosaTensor = tosa::getConstTensor<int64_t>(
+                       rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
+                       .value();
     }
   } else {
     return rewriter.notifyMatchFailure(op, "Usupported element type");
@@ -5320,7 +5329,7 @@ public:
 };
 
 template <typename AtenOpT>
-class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
@@ -5336,18 +5345,48 @@ public:
           op, "Only Tensor types with static shapes are currently supported");
 
     Type outElemTy = outType.getElementType();
-    if (!outElemTy.isIntOrFloat()) {
+    if (!outElemTy.isIntOrFloat())
       return rewriter.notifyMatchFailure(
           op, "Only floating-point or integer datatype legalization supported");
-    }
-
-    Value constOp;
-    if (failed(torchScalarToTosaTensor(
-            rewriter, op, op.getValue(), constOp, outElemTy,
-            makeShapeTorchCompatible(outType.getShape()))))
-      return rewriter.notifyMatchFailure(
-          op, "Supplied value must be a Scalar constant");
-    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
+
+    Value fillValueTargetTensor;
+    if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
+      // Reshape value tensor to have same rank and shape as input
+      auto inputRank =
+          cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
+
+      auto fillValue = adaptor.getValue();
+      auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
+      if (!fillValueType)
+        return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
+      auto fillValueElemTy = fillValueType.getElementType();
+
+      SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
+
+      auto fillValueMatchedInputRankType = RankedTensorType::get(
+          makeShapeTorchCompatible(fillValueMatchedInputRankShape),
+          fillValueElemTy);
+
+      auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(), fillValueMatchedInputRankType, fillValue,
+          rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
+
+      fillValueTargetTensor = rewriter.create<tosa::TileOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
+                                fillValueElemTy),
+          fillValueMatchedInputRankTensor.getResult(),
+          makeShapeTorchCompatible(outType.getShape()));
+    } else {
+      if (failed(torchScalarToTosaTensor(
+              rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
+              makeShapeTorchCompatible(outType.getShape()))))
+        return rewriter.notifyMatchFailure(
+            op, "Fill value must be a scalar constant");
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
+                                              fillValueTargetTensor);
 
     return success();
   }
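
The fill.Tensor branch above rank-expands the value tensor to an all-ones shape
and tiles it out to the static result shape before casting. A minimal NumPy
analogue of that reshape-then-tile step (names and shapes here are illustrative,
not part of the patch):

    import numpy as np

    value = np.array([7], dtype=np.int32)   # fill value tensor of shape [1]
    out_shape = (1, 12, 128, 128)           # static output shape

    # Mirror of the tosa.reshape + tosa.tile sequence in the pattern above.
    expanded = value.reshape((1,) * len(out_shape))
    filled = np.tile(expanded, out_shape)
    assert filled.shape == out_shape and (filled == 7).all()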
@@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.flip
+template <>
+LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
+    AtenFlipOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto self = adaptor.getSelf();
+
+  auto selfTy = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are currently supported");
+
+  SmallVector<int64_t> dims;
+  if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dims are currently supported");
+
+  auto selfRank = selfTy.getRank();
+
+  auto resultTy = getTypeConverter()->convertType(op.getType());
+  Value result = self;
+
+  for (auto &dim : dims) {
+    dim = toPositiveDim(dim, selfRank);
+    if (!isValidDim(dim, selfRank))
+      return rewriter.notifyMatchFailure(op, "Not all dims are valid");
+
+    result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
+                                              static_cast<int32_t>(dim));
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Legalization for aten.round:
+// Rounds elements of input to the nearest integer.
+// Implements "round half to even" to break ties when a number is equidistant
+// from two integers.
+template <>
+LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
+    AtenRoundOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // To round to the nearest integer, we will consider the fractional part of
+  // the input element (= input element - integer part of element). If the
+  // fractional part is smaller than 0.5, round the number down. If the
+  // fractional part is 0.5, apply "round half to even" rule. If the fractional
+  // part is greater than 0.5, round up.
+  //
+  // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
+  //   res = floor(input)
+  // else:
+  //   res = ceil(input)
+
+  auto self = adaptor.getSelf();
+
+  auto selfTy = dyn_cast<TensorType>(self.getType());
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(op, "Only tensor types supported");
+
+  auto resultTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+
+  auto boolTy =
+      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
+
+  auto resultElemTy = resultTy.getElementType();
+
+  auto oneHalf =
+      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
+
+  auto two =
+      tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
+
+  auto floorInput =
+      rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
+
+  // input - floor(input)
+  auto fractionalPart = rewriter.create<tosa::SubOp>(
+      op->getLoc(), resultTy, self, floorInput.getResult());
+
+  auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
+
+  auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
+
+  auto floorDivResult = rewriter.create<tosa::FloorOp>(
+      op->getLoc(), resultTy, floorInputDivByTwo.getResult());
+
+  // (floor(input) // 2) * 2
+  auto evenComparison = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
+
+  // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
+  auto floorInputEven = rewriter.create<tosa::EqualOp>(
+      op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
+
+  auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
+      op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
+
+  auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
+      op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
+
+  // (frac == 0.5) && (floor(input) % 2 == 0)
+  auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
+      op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
+      floorInputEven.getResult());
+
+  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
+  auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
+      op->getLoc(), boolTy, fracLtOneHalf.getResult(),
+      fracEqualOneHalfCond.getResult());
+
+  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
+      op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
+      ceilInput.getResult());
+
+  return success();
+}
+
 // Template to create supporting diagonal mask tensor for aten.diagonal
 template <typename T>
 Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
@@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
   return success();
 }
 
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -6283,11 +6444,13 @@ public:
   INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
 #undef INSERT_CONSTANT_FILL_PATTERN
 
-#define INSERT_FILL_SCALAR_PATTERN(AtenOp)                                     \
+#define INSERT_FILL_PATTERN(AtenOp)                                            \
   target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
-  INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
-#undef INSERT_FILL_SCALAR_PATTERN
+  patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
+  INSERT_FILL_PATTERN(AtenFill_ScalarOp);
+  INSERT_FILL_PATTERN(AtenFillScalarOp);
+  INSERT_FILL_PATTERN(AtenFillTensorOp);
+#undef INSERT_FILL_PATTERN
 
 #define INSERT_MASKED_FILL_PATTERN(AtenOp)                                     \
   target.addIllegalOp<AtenOp>();                                               \
@@ -6359,6 +6522,8 @@ public:
   INSERT_ATENOP_PATTERN(AtenTrilOp);
   INSERT_ATENOP_PATTERN(AtenDiagonalOp);
   INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
+  INSERT_ATENOP_PATTERN(AtenFlipOp);
+  INSERT_ATENOP_PATTERN(AtenRoundOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
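
To make the tie-breaking in the new aten.round lowering concrete, the same
floor/ceil decomposition can be sketched in NumPy (illustrative only;
round_half_to_even is a hypothetical helper, not part of this patch):

    import numpy as np

    def round_half_to_even(x):
        # frac < 0.5, or frac == 0.5 with an even floor, selects floor(x);
        # otherwise ceil(x) -- the same predicate the TOSA pattern builds.
        floor_x = np.floor(x)
        frac = x - floor_x
        floor_is_even = (np.floor(floor_x * 0.5) * 2.0) == floor_x
        take_floor = (frac < 0.5) | ((frac == 0.5) & floor_is_even)
        return np.where(take_floor, floor_x, np.ceil(x))

    x = np.array([2.3, 2.5, 3.5, -2.5, 2.7])
    print(round_half_to_even(x))  # [ 2.  2.  4. -2.  3.]
    print(np.round(x))            # NumPy also rounds half to even: same result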


@@ -1663,6 +1663,22 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "AtenRoundFloatHalfToEvenModule_basic",
+    "AtenRoundFloatModule_basic",
+    "FakeQuantizePerTensorAffineCachemaskModule_basic",
+    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
+    "FakeQuantizePerTensorAffineModule_basic",
+    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
+    "Fill_TensorFloat64WithFloat32Static_basic",
+    "Fill_TensorFloat64WithInt64Static_basic",
+    "FlipModuleStaticShape_basic",
+    "FlipModule_basic",
+    "FlipNegativeIndexModule_basic",
+    "Rot90BasicModule_basic",
+    "Rot90DynamicDimsModule_basic",
+    "Rot90MultipleRotationsModule_basic",
+    "Rot90NegativeEvenRotationsModule_basic",
+    "Rot90NegativeOddRotationsModule_basic",
     "AtenLinalgCrossBroadcast_basic",
     "AtenLinalgCrossCustomDim_basic",
     "AtenLinalgCrossFloat_basic",
@@ -1819,7 +1835,6 @@ TOSA_PASS_SET = {
     "ArangeStartOutModule_basic",
     "ArangeStartOutViewModule_basic",
     "ArangeStartStepIntModule_basic",
-    "ArangeZeroElementOutputModule_basic",
     "ArangeDtypeIntModule_basic",
     "ArangeFalsePinMemoryModule_basic",
     "ArangeFloatModule_basic",
@@ -2120,7 +2135,6 @@ TOSA_PASS_SET = {
     "NormScalarOptDimModule_basic",
     "NumToTensorFloatModule_basic",
     "NumToTensorIntModule_basic",
-    "NumpyTRank0Module_basic",
     "NumpyTRank1Module_basic",
     "NumpyTRank2Module_basic",
     "NumpyTRankNDynamicModule_basic",
@@ -2132,7 +2146,6 @@ TOSA_PASS_SET = {
     "OnesModuleInt_basic",
     "PadModule_basic",
     "PadWithNoneValModule_basic",
-    "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
@@ -2171,7 +2184,6 @@ TOSA_PASS_SET = {
     "ScalarTensorInt64Module_basic",
     "SelectIntNegativeDimAndIndexStaticModule_basic",
     "SiluModule_basic",
-    "SliceOutOfUpperBoundIndexStaticModule_basic",
     "SliceStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
@@ -3222,6 +3234,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "ArangeZeroElementOutputModule_basic",
+    "NumpyTRank0Module_basic",
+    "Permute0RankModule_basic",
+    "SliceOutOfUpperBoundIndexModule_basic",
+    "SliceOutOfUpperBoundIndexStaticModule_basic",
+    "SliceStartEqEndModule_basic",
     "ChunkListUnpackDynamic_Module_basic",
     "ChunkListUnpackUnevenDynamic_Module_basic",
     "ChunkListUnpackUneven_Module_basic",
@@ -3240,11 +3258,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
     "HstackBasicIntModule_basic",
-    "Rot90BasicModule_basic",
-    "Rot90DynamicDimsModule_basic",
-    "Rot90MultipleRotationsModule_basic",
-    "Rot90NegativeEvenRotationsModule_basic",
-    "Rot90NegativeOddRotationsModule_basic",
     "AtenIntMM_basic",
     "AtenKthvalueDynamicDimsModule_basic",
     "AtenKthvalueFloat64DynamicDimsModule_basic",
@@ -3263,7 +3276,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseRreluEvalStaticModule_basic",
     "ElementwiseRreluTrainModule_basic",
     "ElementwiseRreluTrainStaticModule_basic",
-    "FakeQuantizePerTensorAffineCachemaskModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxUnpool3dModulePad0_basic",
@@ -3342,8 +3354,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "AtenMmQuint8_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
-    "AtenRoundFloatHalfToEvenModule_basic",
-    "AtenRoundFloatModule_basic",
     "AtenSubFloatModule_basic",
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
@@ -3504,20 +3514,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "EqIntModule_basic",
     "ExpandModule_basic",
     "ExponentialModule_basic",
-    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
-    "FakeQuantizePerTensorAffineModule_basic",
-    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
-    "Fill_TensorFloat32WithFloat32_basic",
-    "Fill_TensorFloat32WithFloat64_basic",
-    "Fill_TensorFloat32WithInt64_basic",
-    "Fill_TensorFloat64WithFloat32Static_basic",
-    "Fill_TensorFloat64WithFloat32_basic",
-    "Fill_TensorFloat64WithFloat64_basic",
-    "Fill_TensorFloat64WithInt64Static_basic",
-    "Fill_TensorFloat64WithInt64_basic",
-    "FlipModuleStaticShape_basic",
-    "FlipModule_basic",
-    "FlipNegativeIndexModule_basic",
     "FloatImplicitModule_basic",
     "FullLikeModuleInt2D_basic",
     "FullLikeModuleInt3D_basic",
@@ -3847,9 +3843,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "VarMeanUnbiasedModule_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
-    "ZeroFloat32Module_basic",
-    "ZeroInt32Module_basic",
-    "ZeroInt64Module_basic",
+    "VisionTransformerModule_basic",
     "ZerosLikeModule_falsePinMemory",
 }
@@ -3862,6 +3856,12 @@ ONNX_TOSA_CRASHING_SET = {
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "ArangeZeroElementOutputModule_basic",
+    "LinspaceEmptyModule_basic",
+    "RepeatInterleaveSelfIntNoDimModule_basic",
+    "SliceOutOfUpperBoundIndexStaticModule_basic",
+    "TrilIndicesAllZerosModule_basic",
+    "TriuIndicesAllZerosModule_basic",
     "ElementwiseCreateComplexModule_basic",
     "ReduceAllDimFloatModule_basic",
     "AdaptiveMaxPool1dDimOneStatic_basic",
@@ -4026,8 +4026,6 @@ ONNX_TOSA_XFAIL_SET = {
     "AtenPolarDoubleModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
-    "AtenRoundFloatHalfToEvenModule_basic",
-    "AtenRoundFloatModule_basic",
     "AtenSubFloatModule_basic",
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
@@ -4071,8 +4069,6 @@ ONNX_TOSA_XFAIL_SET = {
     "BucketizeTensorFloatModule_basic",
     "BucketizeTensorModule_basic",
     "BucketizeTensorOutInt32RightModule_basic",
-    "BucketizeTensorStaticFloatModule_basic",
-    "BucketizeTensorStaticModule_basic",
     "CeilFloatModule_basic",
     "ChunkListUnpackDynamic_Module_basic",
     "ChunkListUnpackUnevenDynamic_Module_basic",


@@ -1917,3 +1917,84 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t
   %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
   return %0 : !torch.vtensor<[4,5,2],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
+// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
+// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.flip(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
+// CHECK: }
+func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
+  return %1 : !torch.vtensor<[3,4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.round(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
+// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
+// CHECK: }
+func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+  %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}