mirror of https://github.com/llvm/torch-mlir
[TOSA] Add promote type to unary ops and aten.cat lowering (#3860)
Change-Id: I2699bf9007723fe629edb1c524c10ef8142e0234 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3868/head
parent
b6f04fa32b
commit
8eb34dae78
|
@ -74,11 +74,16 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
auto self = adaptor.getSelf();
|
||||||
op,
|
|
||||||
|
auto outType = dyn_cast<TensorType>(
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()));
|
||||||
adaptor.getSelf());
|
|
||||||
|
self = tosa::promoteType(rewriter, self, outType);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -6091,6 +6096,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
auto builtinTensors =
|
auto builtinTensors =
|
||||||
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
|
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
|
||||||
|
|
||||||
|
for (auto &tensor : builtinTensors)
|
||||||
|
tensor = tosa::promoteType(rewriter, tensor, outType);
|
||||||
|
|
||||||
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
|
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
|
||||||
rewriter.replaceOp(op, result.getResult());
|
rewriter.replaceOp(op, result.getResult());
|
||||||
|
|
|
@ -1744,6 +1744,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||||
|
"ElementwiseCosIntModule_basic",
|
||||||
|
"ElementwiseReciprocalIntModule_basic",
|
||||||
|
"ElementwiseRsqrtIntModule_basic",
|
||||||
|
"ElementwiseSinIntModule_basic",
|
||||||
|
"FloatPowerTensorTensorStaticModule_basic",
|
||||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
"CollapseAllDimensionsModule_basic",
|
"CollapseAllDimensionsModule_basic",
|
||||||
"CollapseRank1DynamicModule_basic",
|
"CollapseRank1DynamicModule_basic",
|
||||||
|
@ -1786,7 +1792,6 @@ TOSA_PASS_SET = {
|
||||||
"SliceCopy_Module_basic",
|
"SliceCopy_Module_basic",
|
||||||
"Threshold1dIntModule_basic",
|
"Threshold1dIntModule_basic",
|
||||||
"Threshold2dIntModule_basic",
|
"Threshold2dIntModule_basic",
|
||||||
"Threshold3dIntModule_basic",
|
|
||||||
"EmptyModule_contiguous",
|
"EmptyModule_contiguous",
|
||||||
"EmptyModule_defaultDtype",
|
"EmptyModule_defaultDtype",
|
||||||
"EmptyModule_falsePinMemory",
|
"EmptyModule_falsePinMemory",
|
||||||
|
@ -2435,6 +2440,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"IsInfiniteModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
"ResNet18StaticModule_basic",
|
"ResNet18StaticModule_basic",
|
||||||
|
@ -2510,6 +2516,8 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
}
|
}
|
||||||
) - {
|
) - {
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
|
"FloatPowerTensorTensorStaticModule_basic",
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||||
|
@ -3390,6 +3398,11 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"IsInfiniteModule_basic",
|
||||||
|
"LayerNormFwAndBwModule_basic",
|
||||||
|
"LayerNormManualFwAndBwModule_basic",
|
||||||
|
"SelfAttentionFwAndBwModule_basic",
|
||||||
|
"Threshold3dIntModule_basic",
|
||||||
"ElementwiseCopysignModule_basic",
|
"ElementwiseCopysignModule_basic",
|
||||||
"ElementwiseSignbitModule_basic",
|
"ElementwiseSignbitModule_basic",
|
||||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||||
|
@ -3417,9 +3430,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenPolarDoubleModule_basic",
|
"AtenPolarDoubleModule_basic",
|
||||||
"AtenPolarFloatModule_basic",
|
"AtenPolarFloatModule_basic",
|
||||||
"HstackBasicComplexModule_basic",
|
"HstackBasicComplexModule_basic",
|
||||||
"HstackBasicFloatModule_basic",
|
|
||||||
"HstackBasicIntFloatModule_basic",
|
|
||||||
"HstackBasicIntModule_basic",
|
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
@ -3597,8 +3607,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
"ElementwiseAtanhIntModule_basic",
|
"ElementwiseAtanhIntModule_basic",
|
||||||
"ElementwiseAtanhModule_basic",
|
"ElementwiseAtanhModule_basic",
|
||||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
|
||||||
"ElementwiseCosIntModule_basic",
|
|
||||||
"ElementwiseCoshIntModule_basic",
|
"ElementwiseCoshIntModule_basic",
|
||||||
"ElementwiseCoshModule_basic",
|
"ElementwiseCoshModule_basic",
|
||||||
"ElementwiseDequantizePerChannelModule_basic",
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
|
@ -3620,10 +3628,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
"ElementwiseReciprocalIntModule_basic",
|
|
||||||
"ElementwiseRsqrtIntModule_basic",
|
|
||||||
"ElementwiseSigmoidIntModule_basic",
|
"ElementwiseSigmoidIntModule_basic",
|
||||||
"ElementwiseSinIntModule_basic",
|
|
||||||
"ElementwiseSinhIntModule_basic",
|
"ElementwiseSinhIntModule_basic",
|
||||||
"ElementwiseSinhModule_basic",
|
"ElementwiseSinhModule_basic",
|
||||||
"ElementwiseTanIntModule_basic",
|
"ElementwiseTanIntModule_basic",
|
||||||
|
@ -3850,8 +3855,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"TensorToFloat_basic",
|
"TensorToFloat_basic",
|
||||||
"TensorToIntZeroRank_basic",
|
"TensorToIntZeroRank_basic",
|
||||||
"TensorToInt_basic",
|
"TensorToInt_basic",
|
||||||
"TensorsConcatPromoteDTypeModule_basic",
|
|
||||||
"TensorsStackPromoteDTypeModule_basic",
|
|
||||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||||
"ThresholdBackward2dMixedModule_basic",
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
||||||
|
@ -3931,6 +3934,8 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"FloatPowerTensorTensorStaticModule_basic",
|
||||||
|
"IsInfiniteModule_basic",
|
||||||
"ElementwiseCopysignModule_basic",
|
"ElementwiseCopysignModule_basic",
|
||||||
"ElementwiseFracModule_basic",
|
"ElementwiseFracModule_basic",
|
||||||
"ElementwiseLdexpModule_basic",
|
"ElementwiseLdexpModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue