[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)

Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3689/head
justin-ngo-arm 2024-09-05 11:27:29 -07:00 committed by GitHub
parent b790061b69
commit d4b5e05ac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 277 additions and 97 deletions

View File

@ -22,6 +22,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric> #include <numeric>
#include <optional> #include <optional>
@ -5385,6 +5386,114 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
return success(); return success();
} }
// Builds the constant mask tensor used by the aten.tril legalization.
//
// The mask data is generated row-major for an `h` x `w` matrix: an element at
// (row i, column j) is 1 when it lies on or below the `diagonal`-th diagonal
// and 0 otherwise. `T` is the C++ element type of the constant (e.g. float,
// bool, int32_t, int64_t) and `shape` is the shape handed to
// tosa::getConstTensor for the resulting constant.
//
// NOTE(review): the result of tosa::getConstTensor is unwrapped with
// `.value()`; this presumably cannot be empty as long as `shape` describes
// exactly h * w elements -- confirm against getConstTensor.
template <typename T>
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
                     ArrayRef<int64_t> shape, int64_t h, int64_t w,
                     int64_t diagonal) {
  SmallVector<T> vec;
  // The mask holds exactly h * w entries, so size the buffer once.
  if (h > 0 && w > 0)
    vec.reserve(static_cast<size_t>(h) * static_cast<size_t>(w));

  for (int64_t i = 0; i < h; i++) {
    for (int64_t j = 0; j < w; j++) {
      // Positive diagonal value includes as many diagonals above the main
      // diagonal, while negative diagonal value excludes as many diagonals
      // below the main diagonal.
      //
      // Written as `j - i <= diagonal` rather than `i >= j - diagonal`: the
      // two are mathematically equivalent, but `j - diagonal` overflows for
      // extreme diagonal constants (e.g. INT64_MIN), whereas `j - i` is
      // always a small value because both indices are non-negative.
      const bool inLowerTriangle = (j - i) <= diagonal;
      vec.push_back(static_cast<T>(inLowerTriangle ? 1 : 0));
    }
  }

  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}
// Selects the createTrilMask instantiation matching the element type `type`
// and returns the resulting mask constant for the aten.tril legalization.
//
//  * any floating-point type -> mask built from `float` values
//  * 1-bit integer           -> mask built from `bool` values
//  * 32-bit integer          -> mask built from `int32_t` values
//  * 64-bit integer          -> mask built from `int64_t` values
//
// Any other integer width hits llvm_unreachable. Types that are neither
// floating-point nor integer have no case here, so callers are expected to
// pass only the types listed above.
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
                  ArrayRef<int64_t> shape, int64_t h, int64_t w,
                  int64_t diagonal, Type type) {
  // Mask builder for floating-point element types.
  auto buildFloatMask = [&](mlir::FloatType) -> Value {
    return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
  };

  // Mask builder for integer element types, keyed on the bit width.
  auto buildIntegerMask = [&](mlir::IntegerType integerTy) -> Value {
    const auto bitWidth = integerTy.getWidth();
    if (bitWidth == 1)
      return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
    if (bitWidth == 32)
      return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
    if (bitWidth == 64)
      return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
    llvm_unreachable("Invalid integer width");
  };

  return TypeSwitch<Type, Value>(type)
      .Case<mlir::FloatType>(buildFloatMask)
      .Case<mlir::IntegerType>(buildIntegerMask);
}
// Legalization for aten.tril
//
// Rewrites aten.tril as a tosa.mul of the input with a constant 0/1 mask that
// keeps the lower-triangular part of the two innermost dimensions.
//
// The pattern reports a match failure (leaving the op untouched) when:
//  * the input is not a ranked tensor, has rank < 2, or has a dynamic shape,
//  * the result type cannot be converted to a ranked tensor type,
//  * the element type is one getTrilMask cannot build a mask for,
//  * the diagonal operand is not a constant integer.
template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
    AtenTrilOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a ranked tensor type
  auto selfType = dyn_cast<RankedTensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor types are supported");

  // Rank below 2 not accepted
  auto selfRank = selfType.getRank();
  if (selfRank <= 1)
    return rewriter.notifyMatchFailure(
        op, "Rank 0 and 1 are not accepted as they cause underflow");

  if (!selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Currently only static shapes are supported");

  const TypeConverter *typeConverter = this->getTypeConverter();
  // Use dyn_cast_or_null instead of cast: a failed conversion (null type) or
  // a non-ranked result must end up in the check below rather than tripping
  // the assertion inside cast<>.
  RankedTensorType resultType = dyn_cast_or_null<RankedTensorType>(
      typeConverter->convertType(op->getResult(0).getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");

  // getTrilMask only handles floating-point types and 1/32/64-bit integers;
  // every other element type would run into its llvm_unreachable (or fall
  // through its type switch). Reject those here so the conversion fails
  // gracefully instead of crashing.
  // NOTE(review): for floating-point types the mask is always built from
  // `float` values -- confirm this is acceptable for non-f32 inputs.
  Type maskElementType = resultType.getElementType();
  bool isSupportedElementType = isa<mlir::FloatType>(maskElementType);
  if (auto intType = dyn_cast<mlir::IntegerType>(maskElementType)) {
    const auto width = intType.getWidth();
    isSupportedElementType = width == 1 || width == 32 || width == 64;
  }
  if (!isSupportedElementType)
    return rewriter.notifyMatchFailure(
        op, "Only floating-point and i1/i32/i64 element types are supported");

  // Get height, width of input tensor, and diagonal arg to create
  // a const mask tensor to multiply with input.
  // This mask tensor has the same height and width of input tensor
  // and consists of 1's for the lower triangle part and 0's for the rest.
  // For example, with h=4, w=6, diagonal=1:
  // tensor([[1, 1, 0, 0, 0, 0],
  //         [1, 1, 1, 0, 0, 0],
  //         [1, 1, 1, 1, 0, 0],
  //         [1, 1, 1, 1, 1, 0]])
  auto selfShape = selfType.getShape();
  int64_t h = selfShape[selfRank - 2];
  int64_t w = selfShape[selfRank - 1];
  int64_t diagonal;

  if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal)))
    return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");

  // Define shape for mask tensor based on rank: every leading (batch)
  // dimension is 1, followed by the matrix dimensions h and w.
  SmallVector<int64_t> constShape;
  for (int64_t i = 0; i < selfRank - 2; i++)
    constShape.push_back(1);
  constShape.push_back(h);
  constShape.push_back(w);

  Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
                               maskElementType);

  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
                                           /*shift=*/0);

  return success();
}
} // namespace } // namespace
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -5638,6 +5747,7 @@ public:
INSERT_ATENOP_PATTERN(AtenSqrtOp); INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp); INSERT_ATENOP_PATTERN(AtenIscloseOp);
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -58,8 +58,10 @@ from .xfail_sets import (
FX_IMPORTER_CRASHING_SET, FX_IMPORTER_CRASHING_SET,
FX_IMPORTER_STABLEHLO_XFAIL_SET, FX_IMPORTER_STABLEHLO_XFAIL_SET,
FX_IMPORTER_STABLEHLO_CRASHING_SET, FX_IMPORTER_STABLEHLO_CRASHING_SET,
FX_IMPORTER_TOSA_CRASHING_SET,
FX_IMPORTER_TOSA_XFAIL_SET, FX_IMPORTER_TOSA_XFAIL_SET,
ONNX_TOSA_XFAIL_SET, ONNX_TOSA_XFAIL_SET,
ONNX_TOSA_CRASHING_SET,
) )
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
@ -191,7 +193,7 @@ def main():
elif args.config == "fx_importer_tosa": elif args.config == "fx_importer_tosa":
config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa") config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
xfail_set = FX_IMPORTER_TOSA_XFAIL_SET xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
crashing_set = set() crashing_set = FX_IMPORTER_TOSA_CRASHING_SET
elif args.config == "torchdynamo": elif args.config == "torchdynamo":
# TODO: Enable runtime verification and extend crashing set. # TODO: Enable runtime verification and extend crashing set.
config = TorchDynamoTestConfig( config = TorchDynamoTestConfig(
@ -206,7 +208,7 @@ def main():
elif args.config == "onnx_tosa": elif args.config == "onnx_tosa":
config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa") config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
xfail_set = ONNX_TOSA_XFAIL_SET xfail_set = ONNX_TOSA_XFAIL_SET
crashing_set = set() crashing_set = ONNX_TOSA_CRASHING_SET
do_not_attempt = set( do_not_attempt = set(
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [] args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []

View File

@ -1571,9 +1571,25 @@ TOSA_CRASHING_SET = {
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
} }
# Test names used as the crashing set for the "fx_importer_tosa" e2e test
# config (the runner assigns this set to `crashing_set` for that config).
FX_IMPORTER_TOSA_CRASHING_SET = {
"IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_scales_recompute_bilinear",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"UpSampleNearest2d_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dStaticFactor_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"AtenTrilStaticModule_basic",
"AtenTrilWithNegDiagonalStaticModule_basic",
"AtenTrilWithPosDiagonalStaticModule_basic",
"ArgmaxKeepdimModule_basic", "ArgmaxKeepdimModule_basic",
"MeshgridIndexingIJ_basic", "MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic", "MeshgridIndexingXY_basic",
@ -2938,6 +2954,64 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
} }
FX_IMPORTER_TOSA_XFAIL_SET = { FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic",
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueModule_basic",
"AvgPool3dStaticModule_basic",
"Conv_Transpose1dModule_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic",
"MaxUnpool3dModule_basic",
"MultinomialModule2D_F32",
"MultinomialModule2D_basic",
"MultinomialModule_basic",
"ReduceAminSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAminmaxSingleDim_basic",
"ReduceAnyDimFloatModule_basic",
"RenormModuleFloat16_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex128IntModule_basic",
"TensorsConcatComplex64FloatModule_basic",
"TimeOutModule_basic",
"TrilIndicesAllZerosModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesNegativeOffsetModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesAllZerosModule_basic",
"TriuIndicesModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"TypeConversionUint8ToF32Module_basic",
"WeightNormInterfaceModule_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic",
@ -2960,7 +3034,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveMaxPool3dStatic_basic", "AdaptiveMaxPool3dStatic_basic",
"AddIntModule_basic", "AddIntModule_basic",
"AddFloatIntModule_basic", "AddFloatIntModule_basic",
"Add_MixPModule_basic",
"AllBoolFalseModule_basic", "AllBoolFalseModule_basic",
"AllBoolTrueModule_basic", "AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic", "AnyBoolFalseModule_basic",
@ -2987,7 +3060,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenFloatScalarModule_basic", "AtenFloatScalarModule_basic",
"AtenHannWindowPeriodicTrueModule_basic", "AtenHannWindowPeriodicTrueModule_basic",
"AtenHannWindowPeriodicFalseModule_basic", "AtenHannWindowPeriodicFalseModule_basic",
"AtenInstanceNormModule_basic",
"AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic", "AtenIntBoolOpModule_basic",
@ -3018,9 +3090,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"Aten_CastLongModule_basic", "Aten_CastLongModule_basic",
"Aten_EmbeddingBagExample_basic", "Aten_EmbeddingBagExample_basic",
"AvgPool1dFloatModule_basic", "AvgPool1dFloatModule_basic",
@ -3163,7 +3232,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic", "ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorFloatModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic", "ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
@ -3199,7 +3267,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseMishModule_basic", "ElementwiseMishModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorFloatModule_basic",
"ElementwisePowScalarModule_basic", "ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastModule_basic", "ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic",
@ -3220,14 +3287,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseSinhModule_basic", "ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic", "ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic", "ElementwiseTanModule_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic", "ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"EmptyLikeMemoryFormatModule_basic", "EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype", "EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory", "EmptyLikeModule_falsePinMemory",
@ -3274,8 +3337,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"GridSamplerBasic2_basic", "GridSamplerBasic2_basic",
"GridSamplerBasic3_basic", "GridSamplerBasic3_basic",
"GridSamplerBasic4_basic", "GridSamplerBasic4_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"HBC_basic", "HBC_basic",
@ -3324,21 +3385,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexSelectTwoIdxModule_basic", "IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic", "IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic", "IndexSelectWholeTensorModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateStaticModule_scales_bilinear_align_corners",
@ -3347,9 +3394,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IntImplicitModule_basic", "IntImplicitModule_basic",
"IsFloatingPointFloat_True", "IsFloatingPointFloat_True",
"IsFloatingPointInt_False", "IsFloatingPointInt_False",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"LenStrModule_basic", "LenStrModule_basic",
"LinalgNormKeepDimComplexModule_basic", "LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic", "LinalgVectorNormComplexModule_basic",
@ -3358,7 +3402,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinspaceModule_basic", "LinspaceModule_basic",
"LinspaceOneSizeModule_basic", "LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic", "LinspaceTwoSizeModule_basic",
"LogSoftmaxIntModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic", "MatmulBroadcastBatchDim_basic",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
@ -3412,10 +3455,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NativeDropoutTrainModule_basic", "NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic", "NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormBackwardModule_basic", "NativeGroupNormBackwardModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormDynamicModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleDefaultDtype_basic",
@ -3506,11 +3545,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceL3NormKeepDimComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic",
"ReduceL3NormKeepDimModule_basic", "ReduceL3NormKeepDimModule_basic",
"ReduceMaxAllDims_basic", "ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic", "ReduceMaxFloatModule_basic",
"ReduceMaxKeepDim_basic",
"ReduceMaxSignedIntModule_basic", "ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic",
"ReduceMinAlongDimNegative_basic", "ReduceMinAlongDimNegative_basic",
@ -3601,8 +3637,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"SliceScatterStepVariationModule_basic", "SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic", "SliceScatterZeroDimModule_basic",
"SliceSizeTwoStepModule_basic", "SliceSizeTwoStepModule_basic",
"SoftmaxIntArgTypeF64Module_basic",
"SoftmaxIntNonNoneDtypeModule_basic",
"SoftplusModule_basic", "SoftplusModule_basic",
"SortIntListReverse_basic", "SortIntListReverse_basic",
"SortIntList_basic", "SortIntList_basic",
@ -3615,20 +3649,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"SplitDimStaticModule_basic", "SplitDimStaticModule_basic",
"SqrtIntConstantModule_basic", "SqrtIntConstantModule_basic",
"SqrtIntModule_basic", "SqrtIntModule_basic",
"StdBiasedModule_basic",
"StdCorrectionAllDimReduceModule_basic",
"StdCorrectionEmptyDimModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionLargeInputModule_basic",
"StdCorrectionModule_basic",
"StdCorrectionNoneModule_basic",
"StdCorrectionSingleDimReduceModule_basic",
"StdDimBiasedModule_basic",
"StdDimEmptyDimModule_basic",
"StdDimKeepDimFalseModule_basic",
"StdDimKeepDimTrueModule_basic",
"StdDimNoneDimModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic", "SubFloatModule_basic",
"SubIntModule_basic", "SubIntModule_basic",
"TModuleRank0_basic", "TModuleRank0_basic",
@ -3665,8 +3685,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TraceUnsignedIntModule_empty", "TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic", "TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic", "TypeConversionI1ToI32Module_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UniformModule_basic", "UniformModule_basic",
"UniformNoCorrelationModule_basic", "UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic", "UniformStaticShapeModule_basic",
@ -3679,30 +3697,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic", "UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic", "UpSampleNearest2d_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarCorrectionKeepDimModule_basic",
"VarCorrectionLargeInputModule_basic",
"VarCorrectionModule_basic",
"VarCorrectionNoneModule_basic",
"VarCorrectionSingleDimReduceModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimEmptyDimModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimNoneDimModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarMeanBiasedModule_basic", "VarMeanBiasedModule_basic",
"VarMeanCorrectionModule_basic",
"VarMeanCorrectionNoneModule_basic", "VarMeanCorrectionNoneModule_basic",
"VarMeanDimBiasedModule_basic",
"VarMeanDimModule_basic",
"VarMeanUnbiasedModule_basic", "VarMeanUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
"ZeroFloat32Module_basic", "ZeroFloat32Module_basic",
@ -3711,7 +3708,79 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ZerosLikeModule_falsePinMemory", "ZerosLikeModule_falsePinMemory",
} }
# Test names used as the crashing set for the "onnx_tosa" e2e test config
# (the runner assigns this set to `crashing_set` for that config).
ONNX_TOSA_CRASHING_SET = {
"StdCorrectionEmptyDimModule_basic",
"StdDimEmptyDimModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarDimEmptyDimModule_basic",
"ViewSizeFromOtherTensor_basic",
}
ONNX_TOSA_XFAIL_SET = { ONNX_TOSA_XFAIL_SET = {
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"ArgmaxKeepdimModule_basic",
"AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic",
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueModule_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"AvgPool3dStaticModule_basic",
"Conv_Transpose1dModule_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic",
"ElementwiseRad2DegIntModule_basic",
"ElementwiseRad2DegModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic",
"MaxUnpool3dModule_basic",
"MultinomialModule2D_F32",
"MultinomialModule2D_basic",
"MultinomialModule_basic",
"ReduceAmaxEmptyDim_basic",
"ReduceAminSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAminmaxSingleDim_basic",
"ReduceAnyDimFloatModule_basic",
"RenormModuleFloat16_basic",
"RenormModuleFloat32DynamicDims_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"ScatterAddStaticModule_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex128IntModule_basic",
"TensorsConcatComplex64FloatModule_basic",
"TimeOutModule_basic",
"TypeConversionUint8ToF32Module_basic",
"UnfoldModule_basic",
"WeightNormInterfaceModule_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@ -3929,8 +3998,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAcoshModule_basic", "ElementwiseAcoshModule_basic",
"ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic", "ElementwiseAddScalarIntModule_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwiseAsinIntModule_basic", "ElementwiseAsinIntModule_basic",
"ElementwiseAsinModule_basic", "ElementwiseAsinModule_basic",
"ElementwiseAsinhIntModule_basic", "ElementwiseAsinhIntModule_basic",
@ -3951,7 +4018,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAtenFloorDivideScalarNegativeModule_basic", "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorPositiveModule_basic", "ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseAtenIsinfOpModule_basic",
"ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpModule_basic",
@ -3969,10 +4035,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic", "ElementwiseBitwiseLeftShiftInt8Module_basic",
@ -3987,12 +4049,8 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseClampMaxModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic", "ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampModule_basic", "ElementwiseClampModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorInt8Module_basic", "ElementwiseClampTensorInt8Module_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseCosIntModule_basic", "ElementwiseCosIntModule_basic",
"ElementwiseCosModule_basic", "ElementwiseCosModule_basic",
"ElementwiseCoshIntModule_basic", "ElementwiseCoshIntModule_basic",
@ -4006,7 +4064,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseDivTensorIntegerModule_basic", "ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic", "ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic", "ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
@ -4030,7 +4087,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseGeIntScalarModule_basic", "ElementwiseGeIntScalarModule_basic",
"ElementwiseGeIntTensorModule_basic", "ElementwiseGeIntTensorModule_basic",
"ElementwiseGeMixedIntScalarModule_basic", "ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGtMixed2ScalarModule_basic", "ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseIsinfModule_basic", "ElementwiseIsinfModule_basic",
@ -4084,9 +4140,7 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseUnaryIntModule_basic", "ElementwiseUnaryIntModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarOtherModule_basic", "ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfModule_basic", "ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"ElementwiseWhereSelfModule_basic", "ElementwiseWhereSelfModule_basic",
"EmbeddingModule1DIndices_basic", "EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic", "EmbeddingModuleF16_basic",
@ -4144,8 +4198,6 @@ ONNX_TOSA_XFAIL_SET = {
"HBC_basic", "HBC_basic",
"HardTanhIntModule_basic", "HardTanhIntModule_basic",
"HardTanhModule_basic", "HardTanhModule_basic",
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"HardtanhBackward_basic", "HardtanhBackward_basic",
"IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic",
@ -4216,7 +4268,6 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorStaticNonContiguousWithNoneModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear", "InterpolateDynamicModule_scales_recompute_bilinear",
"IntFloatModule_basic", "IntFloatModule_basic",
"IntImplicitModule_basic", "IntImplicitModule_basic",

View File

@ -1356,3 +1356,20 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v
%1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32>
return %1 : !torch.vtensor<[1,16,270,480],f32> return %1 : !torch.vtensor<[1,16,270,480],f32>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.tril$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],si32>) -> !torch.vtensor<[2,4],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32>
// CHECK: }
// Lowers aten.tril on a static 2x4 si32 tensor with diagonal = 1, i.e. the
// main diagonal and the first diagonal above it are kept.
func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.tril %arg0, %int1 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
  return %0 : !torch.vtensor<[2,4],si32>
}