diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index df5ed5fa8..c033dad1b 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -74,11 +74,16 @@ public: LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, + auto self = adaptor.getSelf(); + + auto outType = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); + op.getType())); + + self = tosa::promoteType(rewriter, self, outType); + + rewriter.replaceOpWithNewOp(op, outType, self); + return success(); } }; @@ -6091,6 +6096,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto builtinTensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); + for (auto &tensor : builtinTensors) + tensor = tosa::promoteType(rewriter, tensor, outType); + auto result = tosa::CreateOpAndInfer( rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); rewriter.replaceOp(op, result.getResult()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8d7aa88ad..2acc3afe5 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1744,6 +1744,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = { # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSinIntModule_basic", + "FloatPowerTensorTensorStaticModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "CollapseAllDimensionsModule_basic", "CollapseRank1DynamicModule_basic", @@ -1786,7 +1792,6 @@ TOSA_PASS_SET = { "SliceCopy_Module_basic", "Threshold1dIntModule_basic", "Threshold2dIntModule_basic", - "Threshold3dIntModule_basic", "EmptyModule_contiguous", "EmptyModule_defaultDtype", "EmptyModule_falsePinMemory", @@ -2435,6 +2440,7 @@ MAKE_FX_TOSA_PASS_SET = ( TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa + "IsInfiniteModule_basic", "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "ResNet18StaticModule_basic", @@ -2510,6 +2516,8 @@ MAKE_FX_TOSA_PASS_SET = ( } ) - { ### Test failing in make_fx_tosa but not in tosa + "AdaptiveMaxPool1dDimOneStatic_basic", + "FloatPowerTensorTensorStaticModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' @@ -3390,6 +3398,11 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | { } FX_IMPORTER_TOSA_XFAIL_SET = { + "IsInfiniteModule_basic", + "LayerNormFwAndBwModule_basic", + "LayerNormManualFwAndBwModule_basic", + "SelfAttentionFwAndBwModule_basic", + "Threshold3dIntModule_basic", "ElementwiseCopysignModule_basic", "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", @@ -3417,9 +3430,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "AtenPolarDoubleModule_basic", "AtenPolarFloatModule_basic", "HstackBasicComplexModule_basic", - "HstackBasicFloatModule_basic", - "HstackBasicIntFloatModule_basic", - "HstackBasicIntModule_basic", "AtenIntMM_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", @@ -3597,8 +3607,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", - "ElementwiseAtenLogicalNotOpPromoteModule_basic", - "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", @@ -3620,10 +3628,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseMulTensorComplexModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", - "ElementwiseReciprocalIntModule_basic", - "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", - "ElementwiseSinIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", "ElementwiseTanIntModule_basic", @@ -3850,8 +3855,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "TensorToFloat_basic", "TensorToIntZeroRank_basic", "TensorToInt_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsStackPromoteDTypeModule_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic", "ThresholdBackward2dMixedModule_basic", "ToCopyWithDTypeFalsePinMemoryModule_basic", @@ -3931,6 +3934,8 @@ ONNX_TOSA_CRASHING_SET = { } ONNX_TOSA_XFAIL_SET = { + "FloatPowerTensorTensorStaticModule_basic", + "IsInfiniteModule_basic", "ElementwiseCopysignModule_basic", "ElementwiseFracModule_basic", "ElementwiseLdexpModule_basic",