mirror of https://github.com/llvm/torch-mlir
[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3689/head
parent
b790061b69
commit
d4b5e05ac1
|
@ -22,6 +22,7 @@
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
@ -5385,6 +5386,114 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Template to create support tril mask tensor for aten.tril
|
||||||
|
// legalization
|
||||||
|
template <typename T>
|
||||||
|
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||||
|
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||||
|
int64_t diagonal) {
|
||||||
|
SmallVector<T> vec;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < h; i++) {
|
||||||
|
for (int64_t j = 0; j < w; j++) {
|
||||||
|
// Positive diagonal value includes as many diagonals above the main
|
||||||
|
// diagonal, while negative diagonal value excludes as many diagonals
|
||||||
|
// below the main diagonal.
|
||||||
|
if (i >= j - diagonal) {
|
||||||
|
vec.push_back(static_cast<T>(1));
|
||||||
|
} else {
|
||||||
|
vec.push_back(static_cast<T>(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to get tril mask tensor based on input type
|
||||||
|
// for aten.tril legalization
|
||||||
|
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||||
|
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||||
|
int64_t diagonal, Type type) {
|
||||||
|
return TypeSwitch<Type, Value>(type)
|
||||||
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 1:
|
||||||
|
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
|
||||||
|
case 32:
|
||||||
|
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
|
||||||
|
case 64:
|
||||||
|
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.tril
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
|
AtenTrilOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a ranked tensor type
|
||||||
|
auto selfType = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types are supported");
|
||||||
|
|
||||||
|
// Rank below 2 not accepted
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
if (selfRank <= 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Rank 0 and 1 are not accepted as they cause underflow");
|
||||||
|
|
||||||
|
if (!selfType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only static shapes are supported");
|
||||||
|
|
||||||
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
if (!resultType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
|
||||||
|
|
||||||
|
// Get height, width of input tensor, and diagonal arg to create
|
||||||
|
// a const mask tensor to multiply with input.
|
||||||
|
// This mask tensor has the same height and width of input tensor
|
||||||
|
// and consists of 1's for the lower triangle part and 0's for the rest.
|
||||||
|
// For example, with h=4, w=6, diagonal=1:
|
||||||
|
// tensor([[1, 1, 0, 0, 0, 0],
|
||||||
|
// [1, 1, 1, 0, 0, 0],
|
||||||
|
// [1, 1, 1, 1, 0, 0],
|
||||||
|
// [1, 1, 1, 1, 1, 0]])
|
||||||
|
auto selfShape = selfType.getShape();
|
||||||
|
int64_t h = selfShape[selfRank - 2];
|
||||||
|
int64_t w = selfShape[selfRank - 1];
|
||||||
|
int64_t diagonal;
|
||||||
|
|
||||||
|
if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
|
||||||
|
|
||||||
|
// Define shape for mask tensor based on rank
|
||||||
|
SmallVector<int64_t> constShape;
|
||||||
|
for (auto i = 0; i < selfRank - 2; i++)
|
||||||
|
constShape.push_back(1);
|
||||||
|
constShape.push_back(h);
|
||||||
|
constShape.push_back(w);
|
||||||
|
|
||||||
|
Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
|
||||||
|
resultType.getElementType());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
|
||||||
|
/*shift=*/0);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -5638,6 +5747,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -58,8 +58,10 @@ from .xfail_sets import (
|
||||||
FX_IMPORTER_CRASHING_SET,
|
FX_IMPORTER_CRASHING_SET,
|
||||||
FX_IMPORTER_STABLEHLO_XFAIL_SET,
|
FX_IMPORTER_STABLEHLO_XFAIL_SET,
|
||||||
FX_IMPORTER_STABLEHLO_CRASHING_SET,
|
FX_IMPORTER_STABLEHLO_CRASHING_SET,
|
||||||
|
FX_IMPORTER_TOSA_CRASHING_SET,
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET,
|
FX_IMPORTER_TOSA_XFAIL_SET,
|
||||||
ONNX_TOSA_XFAIL_SET,
|
ONNX_TOSA_XFAIL_SET,
|
||||||
|
ONNX_TOSA_CRASHING_SET,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
|
@ -191,7 +193,7 @@ def main():
|
||||||
elif args.config == "fx_importer_tosa":
|
elif args.config == "fx_importer_tosa":
|
||||||
config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
|
config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
|
||||||
xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
|
xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
|
||||||
crashing_set = set()
|
crashing_set = FX_IMPORTER_TOSA_CRASHING_SET
|
||||||
elif args.config == "torchdynamo":
|
elif args.config == "torchdynamo":
|
||||||
# TODO: Enanble runtime verification and extend crashing set.
|
# TODO: Enanble runtime verification and extend crashing set.
|
||||||
config = TorchDynamoTestConfig(
|
config = TorchDynamoTestConfig(
|
||||||
|
@ -206,7 +208,7 @@ def main():
|
||||||
elif args.config == "onnx_tosa":
|
elif args.config == "onnx_tosa":
|
||||||
config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
|
config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
|
||||||
xfail_set = ONNX_TOSA_XFAIL_SET
|
xfail_set = ONNX_TOSA_XFAIL_SET
|
||||||
crashing_set = set()
|
crashing_set = ONNX_TOSA_CRASHING_SET
|
||||||
|
|
||||||
do_not_attempt = set(
|
do_not_attempt = set(
|
||||||
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
|
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
|
||||||
|
|
|
@ -1571,9 +1571,25 @@ TOSA_CRASHING_SET = {
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
|
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||||
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
|
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||||
|
"UpSampleNearest2d_basic",
|
||||||
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
|
"UpSampleNearest2dDynamicSize_basic",
|
||||||
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
|
}
|
||||||
|
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"AtenTrilStaticModule_basic",
|
||||||
|
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||||
|
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||||
"ArgmaxKeepdimModule_basic",
|
"ArgmaxKeepdimModule_basic",
|
||||||
"MeshgridIndexingIJ_basic",
|
"MeshgridIndexingIJ_basic",
|
||||||
"MeshgridIndexingXY_basic",
|
"MeshgridIndexingXY_basic",
|
||||||
|
@ -2938,6 +2954,64 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
|
"AtenIntMM_basic",
|
||||||
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
"AtenKthvalueFloat64Module_basic",
|
||||||
|
"AtenKthvalueKeepDimModule_basic",
|
||||||
|
"AtenKthvalueModule_basic",
|
||||||
|
"AvgPool3dStaticModule_basic",
|
||||||
|
"Conv_Transpose1dModule_basic",
|
||||||
|
"Conv_Transpose1dStaticModule_basic",
|
||||||
|
"Conv_Transpose2dStaticModule_basic",
|
||||||
|
"Conv_Transpose3dModule_basic",
|
||||||
|
"Conv_Transpose3dStaticModule_basic",
|
||||||
|
"EinsumStaticDiagonalDimensionModule_basic",
|
||||||
|
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||||
|
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRreluEvalModule_basic",
|
||||||
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
|
"ElementwiseRreluTrainModule_basic",
|
||||||
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||||
|
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||||
|
"MaskedScatterStaticBasic_basic",
|
||||||
|
"MaxUnpool3dModulePad0_basic",
|
||||||
|
"MaxUnpool3dModule_basic",
|
||||||
|
"MultinomialModule2D_F32",
|
||||||
|
"MultinomialModule2D_basic",
|
||||||
|
"MultinomialModule_basic",
|
||||||
|
"ReduceAminSingleDim_basic",
|
||||||
|
"ReduceAminmaxAllDims_basic",
|
||||||
|
"ReduceAminmaxSingleDim_basic",
|
||||||
|
"ReduceAnyDimFloatModule_basic",
|
||||||
|
"RenormModuleFloat16_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
|
"ScatterAddStaticModule_basic",
|
||||||
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
|
"TensorsConcatComplex128IntModule_basic",
|
||||||
|
"TensorsConcatComplex64FloatModule_basic",
|
||||||
|
"TimeOutModule_basic",
|
||||||
|
"TrilIndicesAllZerosModule_basic",
|
||||||
|
"TrilIndicesModule_basic",
|
||||||
|
"TrilIndicesNegativeOffsetModule_basic",
|
||||||
|
"TrilIndicesOfssetGreaterThanRowModule_basic",
|
||||||
|
"TriuIndicesAllZerosModule_basic",
|
||||||
|
"TriuIndicesModule_basic",
|
||||||
|
"TriuIndicesNegativeOffsetModule_basic",
|
||||||
|
"TypeConversionUint8ToF32Module_basic",
|
||||||
|
"WeightNormInterfaceModule_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
|
@ -2960,7 +3034,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AdaptiveMaxPool3dStatic_basic",
|
"AdaptiveMaxPool3dStatic_basic",
|
||||||
"AddIntModule_basic",
|
"AddIntModule_basic",
|
||||||
"AddFloatIntModule_basic",
|
"AddFloatIntModule_basic",
|
||||||
"Add_MixPModule_basic",
|
|
||||||
"AllBoolFalseModule_basic",
|
"AllBoolFalseModule_basic",
|
||||||
"AllBoolTrueModule_basic",
|
"AllBoolTrueModule_basic",
|
||||||
"AnyBoolFalseModule_basic",
|
"AnyBoolFalseModule_basic",
|
||||||
|
@ -2987,7 +3060,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenFloatScalarModule_basic",
|
"AtenFloatScalarModule_basic",
|
||||||
"AtenHannWindowPeriodicTrueModule_basic",
|
"AtenHannWindowPeriodicTrueModule_basic",
|
||||||
"AtenHannWindowPeriodicFalseModule_basic",
|
"AtenHannWindowPeriodicFalseModule_basic",
|
||||||
"AtenInstanceNormModule_basic",
|
|
||||||
"AtenIntBoolOpConstFalseModule_basic",
|
"AtenIntBoolOpConstFalseModule_basic",
|
||||||
"AtenIntBoolOpConstTrueModule_basic",
|
"AtenIntBoolOpConstTrueModule_basic",
|
||||||
"AtenIntBoolOpModule_basic",
|
"AtenIntBoolOpModule_basic",
|
||||||
|
@ -3018,9 +3090,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
"AtenTrilModule_basic",
|
|
||||||
"AtenTrilWithNegDiagonalModule_basic",
|
|
||||||
"AtenTrilWithPosDiagonalModule_basic",
|
|
||||||
"Aten_CastLongModule_basic",
|
"Aten_CastLongModule_basic",
|
||||||
"Aten_EmbeddingBagExample_basic",
|
"Aten_EmbeddingBagExample_basic",
|
||||||
"AvgPool1dFloatModule_basic",
|
"AvgPool1dFloatModule_basic",
|
||||||
|
@ -3163,7 +3232,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||||
"ElementwiseDivTensorFloatModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||||
|
@ -3199,7 +3267,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseMishModule_basic",
|
"ElementwiseMishModule_basic",
|
||||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
"ElementwiseMulTensorFloatModule_basic",
|
|
||||||
"ElementwisePowScalarModule_basic",
|
"ElementwisePowScalarModule_basic",
|
||||||
"ElementwisePowTensorBroadcastModule_basic",
|
"ElementwisePowTensorBroadcastModule_basic",
|
||||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||||
|
@ -3220,14 +3287,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseSinhModule_basic",
|
"ElementwiseSinhModule_basic",
|
||||||
"ElementwiseTanIntModule_basic",
|
"ElementwiseTanIntModule_basic",
|
||||||
"ElementwiseTanModule_basic",
|
"ElementwiseTanModule_basic",
|
||||||
"ElementwiseTernaryModule_basic",
|
|
||||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
"ElementwiseUnaryIntModule_basic",
|
"ElementwiseUnaryIntModule_basic",
|
||||||
"ElementwiseWhereScalarOtherModule_basic",
|
|
||||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||||
"ElementwiseWhereScalarSelfModule_basic",
|
|
||||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
|
||||||
"EmptyLikeMemoryFormatModule_basic",
|
"EmptyLikeMemoryFormatModule_basic",
|
||||||
"EmptyLikeModule_defaultDtype",
|
"EmptyLikeModule_defaultDtype",
|
||||||
"EmptyLikeModule_falsePinMemory",
|
"EmptyLikeModule_falsePinMemory",
|
||||||
|
@ -3274,8 +3337,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"GridSamplerBasic2_basic",
|
"GridSamplerBasic2_basic",
|
||||||
"GridSamplerBasic3_basic",
|
"GridSamplerBasic3_basic",
|
||||||
"GridSamplerBasic4_basic",
|
"GridSamplerBasic4_basic",
|
||||||
"GroupNormModule_basic",
|
|
||||||
"GroupNormNoWeightAndBiasModule_basic",
|
|
||||||
"GtFloatIntModule_basic",
|
"GtFloatIntModule_basic",
|
||||||
"GtIntModule_basic",
|
"GtIntModule_basic",
|
||||||
"HBC_basic",
|
"HBC_basic",
|
||||||
|
@ -3324,21 +3385,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"IndexSelectTwoIdxModule_basic",
|
"IndexSelectTwoIdxModule_basic",
|
||||||
"IndexSelectWholeDimensionModule_basic",
|
"IndexSelectWholeDimensionModule_basic",
|
||||||
"IndexSelectWholeTensorModule_basic",
|
"IndexSelectWholeTensorModule_basic",
|
||||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorMultiInputContiguousCenter_basic",
|
|
||||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguous_basic",
|
|
||||||
"IndexTensorMultiInputOneDim_basic",
|
|
||||||
"IndexTensorMultiInputThreeIndexers_basic",
|
|
||||||
"IndexTensorMultiInput_basic",
|
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
"IndexTensorSelectDimModule_basic",
|
|
||||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_nearest",
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||||
|
@ -3347,9 +3394,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"IntImplicitModule_basic",
|
"IntImplicitModule_basic",
|
||||||
"IsFloatingPointFloat_True",
|
"IsFloatingPointFloat_True",
|
||||||
"IsFloatingPointInt_False",
|
"IsFloatingPointInt_False",
|
||||||
"LayerNormLastDimModule_basic",
|
|
||||||
"LayerNormModule_basic",
|
|
||||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
|
||||||
"LenStrModule_basic",
|
"LenStrModule_basic",
|
||||||
"LinalgNormKeepDimComplexModule_basic",
|
"LinalgNormKeepDimComplexModule_basic",
|
||||||
"LinalgVectorNormComplexModule_basic",
|
"LinalgVectorNormComplexModule_basic",
|
||||||
|
@ -3358,7 +3402,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"LinspaceModule_basic",
|
"LinspaceModule_basic",
|
||||||
"LinspaceOneSizeModule_basic",
|
"LinspaceOneSizeModule_basic",
|
||||||
"LinspaceTwoSizeModule_basic",
|
"LinspaceTwoSizeModule_basic",
|
||||||
"LogSoftmaxIntModule_basic",
|
|
||||||
"MaskedFillTensorFloatValueModule_basic",
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
"MatmulBroadcastBatchDim_basic",
|
"MatmulBroadcastBatchDim_basic",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
@ -3412,10 +3455,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"NativeDropoutTrainModule_basic",
|
"NativeDropoutTrainModule_basic",
|
||||||
"NativeDropoutTrainStaticShapeModule_basic",
|
"NativeDropoutTrainStaticShapeModule_basic",
|
||||||
"NativeGroupNormBackwardModule_basic",
|
"NativeGroupNormBackwardModule_basic",
|
||||||
"NativeGroupNormModule_basic",
|
|
||||||
"NativeLayerNormDynamicModule_basic",
|
|
||||||
"NativeLayerNormModule4D_basic",
|
|
||||||
"NativeLayerNormModule_basic",
|
|
||||||
"NeFloatIntModule_basic",
|
"NeFloatIntModule_basic",
|
||||||
"NeIntModule_basic",
|
"NeIntModule_basic",
|
||||||
"NewEmptyModuleDefaultDtype_basic",
|
"NewEmptyModuleDefaultDtype_basic",
|
||||||
|
@ -3506,11 +3545,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReduceL3NormKeepDimComplexModule_basic",
|
"ReduceL3NormKeepDimComplexModule_basic",
|
||||||
"ReduceL3NormKeepDimModule_basic",
|
"ReduceL3NormKeepDimModule_basic",
|
||||||
"ReduceMaxAllDims_basic",
|
"ReduceMaxAllDims_basic",
|
||||||
"ReduceMaxAlongDimNegative_basic",
|
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
"ReduceMaxAlongDim_basic",
|
|
||||||
"ReduceMaxFloatModule_basic",
|
"ReduceMaxFloatModule_basic",
|
||||||
"ReduceMaxKeepDim_basic",
|
|
||||||
"ReduceMaxSignedIntModule_basic",
|
"ReduceMaxSignedIntModule_basic",
|
||||||
"ReduceMaxUnsignedIntModule_basic",
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
"ReduceMinAlongDimNegative_basic",
|
"ReduceMinAlongDimNegative_basic",
|
||||||
|
@ -3601,8 +3637,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"SliceScatterStepVariationModule_basic",
|
"SliceScatterStepVariationModule_basic",
|
||||||
"SliceScatterZeroDimModule_basic",
|
"SliceScatterZeroDimModule_basic",
|
||||||
"SliceSizeTwoStepModule_basic",
|
"SliceSizeTwoStepModule_basic",
|
||||||
"SoftmaxIntArgTypeF64Module_basic",
|
|
||||||
"SoftmaxIntNonNoneDtypeModule_basic",
|
|
||||||
"SoftplusModule_basic",
|
"SoftplusModule_basic",
|
||||||
"SortIntListReverse_basic",
|
"SortIntListReverse_basic",
|
||||||
"SortIntList_basic",
|
"SortIntList_basic",
|
||||||
|
@ -3615,20 +3649,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"SplitDimStaticModule_basic",
|
"SplitDimStaticModule_basic",
|
||||||
"SqrtIntConstantModule_basic",
|
"SqrtIntConstantModule_basic",
|
||||||
"SqrtIntModule_basic",
|
"SqrtIntModule_basic",
|
||||||
"StdBiasedModule_basic",
|
|
||||||
"StdCorrectionAllDimReduceModule_basic",
|
|
||||||
"StdCorrectionEmptyDimModule_basic",
|
|
||||||
"StdCorrectionKeepDimModule_basic",
|
|
||||||
"StdCorrectionLargeInputModule_basic",
|
|
||||||
"StdCorrectionModule_basic",
|
|
||||||
"StdCorrectionNoneModule_basic",
|
|
||||||
"StdCorrectionSingleDimReduceModule_basic",
|
|
||||||
"StdDimBiasedModule_basic",
|
|
||||||
"StdDimEmptyDimModule_basic",
|
|
||||||
"StdDimKeepDimFalseModule_basic",
|
|
||||||
"StdDimKeepDimTrueModule_basic",
|
|
||||||
"StdDimNoneDimModule_basic",
|
|
||||||
"StdUnbiasedModule_basic",
|
|
||||||
"SubFloatModule_basic",
|
"SubFloatModule_basic",
|
||||||
"SubIntModule_basic",
|
"SubIntModule_basic",
|
||||||
"TModuleRank0_basic",
|
"TModuleRank0_basic",
|
||||||
|
@ -3665,8 +3685,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"TraceUnsignedIntModule_empty",
|
"TraceUnsignedIntModule_empty",
|
||||||
"TypeConversionI1ToF64Module_basic",
|
"TypeConversionI1ToF64Module_basic",
|
||||||
"TypeConversionI1ToI32Module_basic",
|
"TypeConversionI1ToI32Module_basic",
|
||||||
"UnbindIntGetItem_Module_basic",
|
|
||||||
"UnbindIntListUnpack_Module_basic",
|
|
||||||
"UniformModule_basic",
|
"UniformModule_basic",
|
||||||
"UniformNoCorrelationModule_basic",
|
"UniformNoCorrelationModule_basic",
|
||||||
"UniformStaticShapeModule_basic",
|
"UniformStaticShapeModule_basic",
|
||||||
|
@ -3679,30 +3697,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"UpSampleNearest2dStaticFactor_basic",
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
"UpSampleNearest2dStaticSize_basic",
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
"UpSampleNearest2d_basic",
|
"UpSampleNearest2d_basic",
|
||||||
"VarBiasedModule_basic",
|
|
||||||
"VarCorrectionAllDimReduceModule_basic",
|
|
||||||
"VarCorrectionEmptyDimModule_basic",
|
|
||||||
"VarCorrectionKeepDimModule_basic",
|
|
||||||
"VarCorrectionLargeInputModule_basic",
|
|
||||||
"VarCorrectionModule_basic",
|
|
||||||
"VarCorrectionNoneModule_basic",
|
|
||||||
"VarCorrectionSingleDimReduceModule_basic",
|
|
||||||
"VarDimAllDimReduceModule_basic",
|
|
||||||
"VarDimBiasedModule_basic",
|
|
||||||
"VarDimEmptyDimModule_basic",
|
|
||||||
"VarDimModule_basic",
|
|
||||||
"VarDimMultiDimModule_basic",
|
|
||||||
"VarDimNegativeModule_basic",
|
|
||||||
"VarDimNoneDimModule_basic",
|
|
||||||
"VarDimSingleDimModule_basic",
|
|
||||||
"VarDimUnbiasedModule_basic",
|
|
||||||
"VarMeanBiasedModule_basic",
|
"VarMeanBiasedModule_basic",
|
||||||
"VarMeanCorrectionModule_basic",
|
|
||||||
"VarMeanCorrectionNoneModule_basic",
|
"VarMeanCorrectionNoneModule_basic",
|
||||||
"VarMeanDimBiasedModule_basic",
|
|
||||||
"VarMeanDimModule_basic",
|
|
||||||
"VarMeanUnbiasedModule_basic",
|
"VarMeanUnbiasedModule_basic",
|
||||||
"VarUnbiasedModule_basic",
|
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"ZeroFloat32Module_basic",
|
"ZeroFloat32Module_basic",
|
||||||
|
@ -3711,7 +3708,79 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ZerosLikeModule_falsePinMemory",
|
"ZerosLikeModule_falsePinMemory",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ONNX_TOSA_CRASHING_SET = {
|
||||||
|
"StdCorrectionEmptyDimModule_basic",
|
||||||
|
"StdDimEmptyDimModule_basic",
|
||||||
|
"VarCorrectionEmptyDimModule_basic",
|
||||||
|
"VarDimEmptyDimModule_basic",
|
||||||
|
"ViewSizeFromOtherTensor_basic",
|
||||||
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||||
|
"ArgmaxKeepdimModule_basic",
|
||||||
|
"AtenIntMM_basic",
|
||||||
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
"AtenKthvalueFloat64Module_basic",
|
||||||
|
"AtenKthvalueKeepDimModule_basic",
|
||||||
|
"AtenKthvalueModule_basic",
|
||||||
|
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||||
|
"AvgPool3dStaticModule_basic",
|
||||||
|
"Conv_Transpose1dModule_basic",
|
||||||
|
"Conv_Transpose1dStaticModule_basic",
|
||||||
|
"Conv_Transpose2dStaticModule_basic",
|
||||||
|
"Conv_Transpose3dModule_basic",
|
||||||
|
"Conv_Transpose3dStaticModule_basic",
|
||||||
|
"EinsumStaticDiagonalDimensionModule_basic",
|
||||||
|
"EinsumStaticModule_basic",
|
||||||
|
"ElementwiseFmaxModule_basic",
|
||||||
|
"ElementwiseFminModule_basic",
|
||||||
|
"ElementwiseGeluApproximateTanhModule_basic",
|
||||||
|
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||||
|
"ElementwiseNanToNumWithNoneModule_Basic",
|
||||||
|
"ElementwiseRad2DegIntModule_basic",
|
||||||
|
"ElementwiseRad2DegModule_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRreluTrainModule_basic",
|
||||||
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||||
|
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||||
|
"MaskedScatterStaticBasic_basic",
|
||||||
|
"MaxUnpool3dModulePad0_basic",
|
||||||
|
"MaxUnpool3dModule_basic",
|
||||||
|
"MultinomialModule2D_F32",
|
||||||
|
"MultinomialModule2D_basic",
|
||||||
|
"MultinomialModule_basic",
|
||||||
|
"ReduceAmaxEmptyDim_basic",
|
||||||
|
"ReduceAminSingleDim_basic",
|
||||||
|
"ReduceAminmaxAllDims_basic",
|
||||||
|
"ReduceAminmaxSingleDim_basic",
|
||||||
|
"ReduceAnyDimFloatModule_basic",
|
||||||
|
"RenormModuleFloat16_basic",
|
||||||
|
"RenormModuleFloat32DynamicDims_basic",
|
||||||
|
"RenormModuleFloat32NegativeDim_basic",
|
||||||
|
"RenormModuleFloat32_basic",
|
||||||
|
"ScatterAddStaticModule_basic",
|
||||||
|
"TensorSplitSections_GetItemModule_basic",
|
||||||
|
"TensorSplitSections_ListUnpackModule_basic",
|
||||||
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
|
"TensorsConcatComplex128IntModule_basic",
|
||||||
|
"TensorsConcatComplex64FloatModule_basic",
|
||||||
|
"TimeOutModule_basic",
|
||||||
|
"TypeConversionUint8ToF32Module_basic",
|
||||||
|
"UnfoldModule_basic",
|
||||||
|
"WeightNormInterfaceModule_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
@ -3929,8 +3998,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAcoshModule_basic",
|
"ElementwiseAcoshModule_basic",
|
||||||
"ElementwiseAddScalarInt64Module_basic",
|
"ElementwiseAddScalarInt64Module_basic",
|
||||||
"ElementwiseAddScalarIntModule_basic",
|
"ElementwiseAddScalarIntModule_basic",
|
||||||
"ElementwiseAndScalarModule_basic",
|
|
||||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAsinModule_basic",
|
"ElementwiseAsinModule_basic",
|
||||||
"ElementwiseAsinhIntModule_basic",
|
"ElementwiseAsinhIntModule_basic",
|
||||||
|
@ -3951,7 +4018,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||||
"ElementwiseAtenIsinfOpModule_basic",
|
|
||||||
"ElementwiseAtenIsneginfOpModule_basic",
|
"ElementwiseAtenIsneginfOpModule_basic",
|
||||||
"ElementwiseAtenIsposinfOpModule_basic",
|
"ElementwiseAtenIsposinfOpModule_basic",
|
||||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||||
|
@ -3969,10 +4035,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
||||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||||
"ElementwiseBitwiseAndModule_basic",
|
"ElementwiseBitwiseAndModule_basic",
|
||||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
|
||||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
|
||||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
|
||||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
|
||||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||||
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||||
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
||||||
|
@ -3987,12 +4049,8 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||||
"ElementwiseClampMaxModule_basic",
|
"ElementwiseClampMaxModule_basic",
|
||||||
"ElementwiseClampMinModule_basic",
|
"ElementwiseClampMinModule_basic",
|
||||||
"ElementwiseClampMinTensorFloatModule_basic",
|
|
||||||
"ElementwiseClampMinTensorIntModule_basic",
|
|
||||||
"ElementwiseClampModule_basic",
|
"ElementwiseClampModule_basic",
|
||||||
"ElementwiseClampTensorFloatModule_basic",
|
|
||||||
"ElementwiseClampTensorInt8Module_basic",
|
"ElementwiseClampTensorInt8Module_basic",
|
||||||
"ElementwiseClampTensorIntModule_basic",
|
|
||||||
"ElementwiseCosIntModule_basic",
|
"ElementwiseCosIntModule_basic",
|
||||||
"ElementwiseCosModule_basic",
|
"ElementwiseCosModule_basic",
|
||||||
"ElementwiseCoshIntModule_basic",
|
"ElementwiseCoshIntModule_basic",
|
||||||
|
@ -4006,7 +4064,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseDivTensorIntegerModule_basic",
|
"ElementwiseDivTensorIntegerModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||||
|
@ -4030,7 +4087,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseGeIntScalarModule_basic",
|
"ElementwiseGeIntScalarModule_basic",
|
||||||
"ElementwiseGeIntTensorModule_basic",
|
"ElementwiseGeIntTensorModule_basic",
|
||||||
"ElementwiseGeMixedIntScalarModule_basic",
|
"ElementwiseGeMixedIntScalarModule_basic",
|
||||||
"ElementwiseGeluModule_basic",
|
|
||||||
"ElementwiseGtMixed2ScalarModule_basic",
|
"ElementwiseGtMixed2ScalarModule_basic",
|
||||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||||
"ElementwiseIsinfModule_basic",
|
"ElementwiseIsinfModule_basic",
|
||||||
|
@ -4084,9 +4140,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseUnaryIntModule_basic",
|
"ElementwiseUnaryIntModule_basic",
|
||||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||||
"ElementwiseWhereScalarOtherModule_basic",
|
"ElementwiseWhereScalarOtherModule_basic",
|
||||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
|
||||||
"ElementwiseWhereScalarSelfModule_basic",
|
"ElementwiseWhereScalarSelfModule_basic",
|
||||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
|
||||||
"ElementwiseWhereSelfModule_basic",
|
"ElementwiseWhereSelfModule_basic",
|
||||||
"EmbeddingModule1DIndices_basic",
|
"EmbeddingModule1DIndices_basic",
|
||||||
"EmbeddingModuleF16_basic",
|
"EmbeddingModuleF16_basic",
|
||||||
|
@ -4144,8 +4198,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"HBC_basic",
|
"HBC_basic",
|
||||||
"HardTanhIntModule_basic",
|
"HardTanhIntModule_basic",
|
||||||
"HardTanhModule_basic",
|
"HardTanhModule_basic",
|
||||||
"HardsigmoidModule_basic",
|
|
||||||
"HardsigmoidRandomModule_basic",
|
|
||||||
"HardtanhBackward_basic",
|
"HardtanhBackward_basic",
|
||||||
"IndexPut1DFloatAccumulateModule_basic",
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||||
|
@ -4216,7 +4268,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_nearest",
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
|
||||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||||
"IntFloatModule_basic",
|
"IntFloatModule_basic",
|
||||||
"IntImplicitModule_basic",
|
"IntImplicitModule_basic",
|
||||||
|
|
|
@ -1356,3 +1356,20 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v
|
||||||
%1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32>
|
%1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32>
|
||||||
return %1 : !torch.vtensor<[1,16,270,480],f32>
|
return %1 : !torch.vtensor<[1,16,270,480],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.tril$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],si32>) -> !torch.vtensor<[2,4],si32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> {
|
||||||
|
%int0 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
|
||||||
|
return %0 : !torch.vtensor<[2,4],si32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue