[TOSA] Add promote type to unary ops and aten.cat lowering (#3860)

Change-Id: I2699bf9007723fe629edb1c524c10ef8142e0234

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3868/head
Justin Ngo 2024-11-08 11:23:39 -08:00 committed by GitHub
parent b6f04fa32b
commit 8eb34dae78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 15 deletions

View File

@ -74,11 +74,16 @@ public:
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
auto self = adaptor.getSelf();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.getSelf());
op.getType()));
self = tosa::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);
return success();
}
};
@ -6091,6 +6096,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
auto builtinTensors =
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
for (auto &tensor : builtinTensors)
tensor = tosa::promoteType(rewriter, tensor, outType);
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
rewriter.replaceOp(op, result.getResult());

View File

@ -1744,6 +1744,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSinIntModule_basic",
"FloatPowerTensorTensorStaticModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"CollapseAllDimensionsModule_basic",
"CollapseRank1DynamicModule_basic",
@ -1786,7 +1792,6 @@ TOSA_PASS_SET = {
"SliceCopy_Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
@ -2435,6 +2440,7 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"IsInfiniteModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18StaticModule_basic",
@ -2510,6 +2516,8 @@ MAKE_FX_TOSA_PASS_SET = (
}
) - {
### Test failing in make_fx_tosa but not in tosa
"AdaptiveMaxPool1dDimOneStatic_basic",
"FloatPowerTensorTensorStaticModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
@ -3390,6 +3398,11 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"IsInfiniteModule_basic",
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
"SelfAttentionFwAndBwModule_basic",
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
@ -3417,9 +3430,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic",
@ -3597,8 +3607,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
@ -3620,10 +3628,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic",
@ -3850,8 +3855,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TensorToFloat_basic",
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"ThresholdBackward2dMixedModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
@ -3931,6 +3934,8 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"FloatPowerTensorTensorStaticModule_basic",
"IsInfiniteModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseFracModule_basic",
"ElementwiseLdexpModule_basic",