mirror of https://github.com/llvm/torch-mlir
[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3689/head
parent
b790061b69
commit
d4b5e05ac1
|
@ -22,6 +22,7 @@
|
|||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
|
||||
|
@ -5385,6 +5386,114 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Template to create support tril mask tensor for aten.tril
|
||||
// legalization
|
||||
template <typename T>
|
||||
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||
int64_t diagonal) {
|
||||
SmallVector<T> vec;
|
||||
|
||||
for (int64_t i = 0; i < h; i++) {
|
||||
for (int64_t j = 0; j < w; j++) {
|
||||
// Positive diagonal value includes as many diagonals above the main
|
||||
// diagonal, while negative diagonal value excludes as many diagonals
|
||||
// below the main diagonal.
|
||||
if (i >= j - diagonal) {
|
||||
vec.push_back(static_cast<T>(1));
|
||||
} else {
|
||||
vec.push_back(static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||
}
|
||||
|
||||
// Function to get tril mask tensor based on input type
|
||||
// for aten.tril legalization
|
||||
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||
int64_t diagonal, Type type) {
|
||||
return TypeSwitch<Type, Value>(type)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
|
||||
case 32:
|
||||
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
|
||||
case 64:
|
||||
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
}
|
||||
|
||||
// Legalization for aten.tril
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||
AtenTrilOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are supported");
|
||||
|
||||
// Rank below 2 not accepted
|
||||
auto selfRank = selfType.getRank();
|
||||
if (selfRank <= 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Rank 0 and 1 are not accepted as they cause underflow");
|
||||
|
||||
if (!selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only static shapes are supported");
|
||||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
if (!resultType)
|
||||
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
|
||||
|
||||
// Get height, width of input tensor, and diagonal arg to create
|
||||
// a const mask tensor to multiply with input.
|
||||
// This mask tensor has the same height and width of input tensor
|
||||
// and consists of 1's for the lower triangle part and 0's for the rest.
|
||||
// For example, with h=4, w=6, diagonal=1:
|
||||
// tensor([[1, 1, 0, 0, 0, 0],
|
||||
// [1, 1, 1, 0, 0, 0],
|
||||
// [1, 1, 1, 1, 0, 0],
|
||||
// [1, 1, 1, 1, 1, 0]])
|
||||
auto selfShape = selfType.getShape();
|
||||
int64_t h = selfShape[selfRank - 2];
|
||||
int64_t w = selfShape[selfRank - 1];
|
||||
int64_t diagonal;
|
||||
|
||||
if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal)))
|
||||
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
|
||||
|
||||
// Define shape for mask tensor based on rank
|
||||
SmallVector<int64_t> constShape;
|
||||
for (auto i = 0; i < selfRank - 2; i++)
|
||||
constShape.push_back(1);
|
||||
constShape.push_back(h);
|
||||
constShape.push_back(w);
|
||||
|
||||
Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
|
||||
resultType.getElementType());
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
|
||||
/*shift=*/0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -5638,6 +5747,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -58,8 +58,10 @@ from .xfail_sets import (
|
|||
FX_IMPORTER_CRASHING_SET,
|
||||
FX_IMPORTER_STABLEHLO_XFAIL_SET,
|
||||
FX_IMPORTER_STABLEHLO_CRASHING_SET,
|
||||
FX_IMPORTER_TOSA_CRASHING_SET,
|
||||
FX_IMPORTER_TOSA_XFAIL_SET,
|
||||
ONNX_TOSA_XFAIL_SET,
|
||||
ONNX_TOSA_CRASHING_SET,
|
||||
)
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
|
@ -191,7 +193,7 @@ def main():
|
|||
elif args.config == "fx_importer_tosa":
|
||||
config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
|
||||
xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
|
||||
crashing_set = set()
|
||||
crashing_set = FX_IMPORTER_TOSA_CRASHING_SET
|
||||
elif args.config == "torchdynamo":
|
||||
# TODO: Enanble runtime verification and extend crashing set.
|
||||
config = TorchDynamoTestConfig(
|
||||
|
@ -206,7 +208,7 @@ def main():
|
|||
elif args.config == "onnx_tosa":
|
||||
config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
|
||||
xfail_set = ONNX_TOSA_XFAIL_SET
|
||||
crashing_set = set()
|
||||
crashing_set = ONNX_TOSA_CRASHING_SET
|
||||
|
||||
do_not_attempt = set(
|
||||
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
|
||||
|
|
|
@ -1571,9 +1571,25 @@ TOSA_CRASHING_SET = {
|
|||
"IndexTensorNegativeIndexModule_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"UpSampleNearest2d_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
"UpSampleNearest2dDynamicSize_basic",
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"UpSampleNearest2dStaticFactor_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AtenTrilStaticModule_basic",
|
||||
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||
"ArgmaxKeepdimModule_basic",
|
||||
"MeshgridIndexingIJ_basic",
|
||||
"MeshgridIndexingXY_basic",
|
||||
|
@ -2938,6 +2954,64 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64Module_basic",
|
||||
"AtenKthvalueKeepDimModule_basic",
|
||||
"AtenKthvalueModule_basic",
|
||||
"AvgPool3dStaticModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
"Conv_Transpose1dStaticModule_basic",
|
||||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRreluEvalModule_basic",
|
||||
"ElementwiseRreluEvalStaticModule_basic",
|
||||
"ElementwiseRreluTrainModule_basic",
|
||||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxUnpool3dModulePad0_basic",
|
||||
"MaxUnpool3dModule_basic",
|
||||
"MultinomialModule2D_F32",
|
||||
"MultinomialModule2D_basic",
|
||||
"MultinomialModule_basic",
|
||||
"ReduceAminSingleDim_basic",
|
||||
"ReduceAminmaxAllDims_basic",
|
||||
"ReduceAminmaxSingleDim_basic",
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"RenormModuleFloat16_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
"TensorsConcatComplex64FloatModule_basic",
|
||||
"TimeOutModule_basic",
|
||||
"TrilIndicesAllZerosModule_basic",
|
||||
"TrilIndicesModule_basic",
|
||||
"TrilIndicesNegativeOffsetModule_basic",
|
||||
"TrilIndicesOfssetGreaterThanRowModule_basic",
|
||||
"TriuIndicesAllZerosModule_basic",
|
||||
"TriuIndicesModule_basic",
|
||||
"TriuIndicesNegativeOffsetModule_basic",
|
||||
"TypeConversionUint8ToF32Module_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
|
@ -2960,7 +3034,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AdaptiveMaxPool3dStatic_basic",
|
||||
"AddIntModule_basic",
|
||||
"AddFloatIntModule_basic",
|
||||
"Add_MixPModule_basic",
|
||||
"AllBoolFalseModule_basic",
|
||||
"AllBoolTrueModule_basic",
|
||||
"AnyBoolFalseModule_basic",
|
||||
|
@ -2987,7 +3060,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenFloatScalarModule_basic",
|
||||
"AtenHannWindowPeriodicTrueModule_basic",
|
||||
"AtenHannWindowPeriodicFalseModule_basic",
|
||||
"AtenInstanceNormModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
|
@ -3018,9 +3090,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
"AtenTrilModule_basic",
|
||||
"AtenTrilWithNegDiagonalModule_basic",
|
||||
"AtenTrilWithPosDiagonalModule_basic",
|
||||
"Aten_CastLongModule_basic",
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
"AvgPool1dFloatModule_basic",
|
||||
|
@ -3163,7 +3232,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorFloatModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
|
@ -3199,7 +3267,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseMishModule_basic",
|
||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwiseMulTensorFloatModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
"ElementwisePowTensorBroadcastModule_basic",
|
||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||
|
@ -3220,14 +3287,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseSinhModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
"ElementwiseTernaryModule_basic",
|
||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
"EmptyLikeMemoryFormatModule_basic",
|
||||
"EmptyLikeModule_defaultDtype",
|
||||
"EmptyLikeModule_falsePinMemory",
|
||||
|
@ -3274,8 +3337,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"GridSamplerBasic2_basic",
|
||||
"GridSamplerBasic3_basic",
|
||||
"GridSamplerBasic4_basic",
|
||||
"GroupNormModule_basic",
|
||||
"GroupNormNoWeightAndBiasModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"HBC_basic",
|
||||
|
@ -3324,21 +3385,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguous_basic",
|
||||
"IndexTensorMultiInputOneDim_basic",
|
||||
"IndexTensorMultiInputThreeIndexers_basic",
|
||||
"IndexTensorMultiInput_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
|
@ -3347,9 +3394,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IntImplicitModule_basic",
|
||||
"IsFloatingPointFloat_True",
|
||||
"IsFloatingPointInt_False",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
"LenStrModule_basic",
|
||||
"LinalgNormKeepDimComplexModule_basic",
|
||||
"LinalgVectorNormComplexModule_basic",
|
||||
|
@ -3358,7 +3402,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"LogSoftmaxIntModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"MatmulBroadcastBatchDim_basic",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
@ -3412,10 +3455,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
"NativeGroupNormModule_basic",
|
||||
"NativeLayerNormDynamicModule_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"NewEmptyModuleDefaultDtype_basic",
|
||||
|
@ -3506,11 +3545,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ReduceL3NormKeepDimComplexModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"ReduceMaxAllDims_basic",
|
||||
"ReduceMaxAlongDimNegative_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMaxAlongDim_basic",
|
||||
"ReduceMaxFloatModule_basic",
|
||||
"ReduceMaxKeepDim_basic",
|
||||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
"ReduceMinAlongDimNegative_basic",
|
||||
|
@ -3601,8 +3637,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"SliceScatterStepVariationModule_basic",
|
||||
"SliceScatterZeroDimModule_basic",
|
||||
"SliceSizeTwoStepModule_basic",
|
||||
"SoftmaxIntArgTypeF64Module_basic",
|
||||
"SoftmaxIntNonNoneDtypeModule_basic",
|
||||
"SoftplusModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
|
@ -3615,20 +3649,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"SplitDimStaticModule_basic",
|
||||
"SqrtIntConstantModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"StdCorrectionAllDimReduceModule_basic",
|
||||
"StdCorrectionEmptyDimModule_basic",
|
||||
"StdCorrectionKeepDimModule_basic",
|
||||
"StdCorrectionLargeInputModule_basic",
|
||||
"StdCorrectionModule_basic",
|
||||
"StdCorrectionNoneModule_basic",
|
||||
"StdCorrectionSingleDimReduceModule_basic",
|
||||
"StdDimBiasedModule_basic",
|
||||
"StdDimEmptyDimModule_basic",
|
||||
"StdDimKeepDimFalseModule_basic",
|
||||
"StdDimKeepDimTrueModule_basic",
|
||||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TModuleRank0_basic",
|
||||
|
@ -3665,8 +3685,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"TraceUnsignedIntModule_empty",
|
||||
"TypeConversionI1ToF64Module_basic",
|
||||
"TypeConversionI1ToI32Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformNoCorrelationModule_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
|
@ -3679,30 +3697,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"UpSampleNearest2dStaticFactor_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
"UpSampleNearest2d_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarCorrectionAllDimReduceModule_basic",
|
||||
"VarCorrectionEmptyDimModule_basic",
|
||||
"VarCorrectionKeepDimModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
"VarCorrectionModule_basic",
|
||||
"VarCorrectionNoneModule_basic",
|
||||
"VarCorrectionSingleDimReduceModule_basic",
|
||||
"VarDimAllDimReduceModule_basic",
|
||||
"VarDimBiasedModule_basic",
|
||||
"VarDimEmptyDimModule_basic",
|
||||
"VarDimModule_basic",
|
||||
"VarDimMultiDimModule_basic",
|
||||
"VarDimNegativeModule_basic",
|
||||
"VarDimNoneDimModule_basic",
|
||||
"VarDimSingleDimModule_basic",
|
||||
"VarDimUnbiasedModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
"VarMeanCorrectionModule_basic",
|
||||
"VarMeanCorrectionNoneModule_basic",
|
||||
"VarMeanDimBiasedModule_basic",
|
||||
"VarMeanDimModule_basic",
|
||||
"VarMeanUnbiasedModule_basic",
|
||||
"VarUnbiasedModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ZeroFloat32Module_basic",
|
||||
|
@ -3711,7 +3708,79 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ZerosLikeModule_falsePinMemory",
|
||||
}
|
||||
|
||||
ONNX_TOSA_CRASHING_SET = {
|
||||
"StdCorrectionEmptyDimModule_basic",
|
||||
"StdDimEmptyDimModule_basic",
|
||||
"VarCorrectionEmptyDimModule_basic",
|
||||
"VarDimEmptyDimModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
"ArgmaxKeepdimModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64Module_basic",
|
||||
"AtenKthvalueKeepDimModule_basic",
|
||||
"AtenKthvalueModule_basic",
|
||||
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||
"AvgPool3dStaticModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
"Conv_Transpose1dStaticModule_basic",
|
||||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||
"ElementwiseNanToNumWithNoneModule_Basic",
|
||||
"ElementwiseRad2DegIntModule_basic",
|
||||
"ElementwiseRad2DegModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRreluTrainModule_basic",
|
||||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxUnpool3dModulePad0_basic",
|
||||
"MaxUnpool3dModule_basic",
|
||||
"MultinomialModule2D_F32",
|
||||
"MultinomialModule2D_basic",
|
||||
"MultinomialModule_basic",
|
||||
"ReduceAmaxEmptyDim_basic",
|
||||
"ReduceAminSingleDim_basic",
|
||||
"ReduceAminmaxAllDims_basic",
|
||||
"ReduceAminmaxSingleDim_basic",
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"RenormModuleFloat16_basic",
|
||||
"RenormModuleFloat32DynamicDims_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
"TensorSplitSections_GetItemModule_basic",
|
||||
"TensorSplitSections_ListUnpackModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
"TensorsConcatComplex64FloatModule_basic",
|
||||
"TimeOutModule_basic",
|
||||
"TypeConversionUint8ToF32Module_basic",
|
||||
"UnfoldModule_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||
|
@ -3929,8 +3998,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAcoshModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseAndScalarModule_basic",
|
||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
||||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAsinModule_basic",
|
||||
"ElementwiseAsinhIntModule_basic",
|
||||
|
@ -3951,7 +4018,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||
"ElementwiseAtenIsinfOpModule_basic",
|
||||
"ElementwiseAtenIsneginfOpModule_basic",
|
||||
"ElementwiseAtenIsposinfOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
|
@ -3969,10 +4035,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseAndModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
||||
|
@ -3987,12 +4049,8 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
"ElementwiseClampTensorInt8Module_basic",
|
||||
"ElementwiseClampTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
|
@ -4006,7 +4064,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseDivTensorIntegerModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
|
@ -4030,7 +4087,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseGeIntScalarModule_basic",
|
||||
"ElementwiseGeIntTensorModule_basic",
|
||||
"ElementwiseGeMixedIntScalarModule_basic",
|
||||
"ElementwiseGeluModule_basic",
|
||||
"ElementwiseGtMixed2ScalarModule_basic",
|
||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||
"ElementwiseIsinfModule_basic",
|
||||
|
@ -4084,9 +4140,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
"ElementwiseWhereSelfModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
|
@ -4144,8 +4198,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"HBC_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardsigmoidModule_basic",
|
||||
"HardsigmoidRandomModule_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
|
@ -4216,7 +4268,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"IntFloatModule_basic",
|
||||
"IntImplicitModule_basic",
|
||||
|
|
|
@ -1356,3 +1356,20 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v
|
|||
%1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32>
|
||||
return %1 : !torch.vtensor<[1,16,270,480],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.tril$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],si32>) -> !torch.vtensor<[2,4],si32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> {
|
||||
%int0 = torch.constant.int 1
|
||||
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
|
||||
return %0 : !torch.vtensor<[2,4],si32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue