mirror of https://github.com/llvm/torch-mlir
[TOSA] Add promote type to unary ops and aten.cat lowering (#3860)
Change-Id: I2699bf9007723fe629edb1c524c10ef8142e0234
Signed-off-by: Justin Ngo <justin.ngo@arm.com>
parent b6f04fa32b
commit 8eb34dae78
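The diff below makes two related changes: the templated unary-op lowering and the aten.cat lowering now cast their tensor operands to the element type of the converted result (via tosa::promoteType) before emitting the TOSA op, so inputs whose dtype differs from the converted result dtype get an explicit cast instead of a type-mismatched TOSA op. As a rough mental model, a promoteType-style helper can be sketched as below; the name promoteTypeSketch and the exact cast construction are assumptions for illustration, not the actual torch-mlir helper.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical sketch of a promoteType-style helper: return `input` unchanged
// when its element type already matches `outType`, otherwise insert a
// tosa.cast to the target element type while keeping the input's shape.
static Value promoteTypeSketch(PatternRewriter &rewriter, Value input,
                               TensorType outType) {
  auto inputTy = cast<RankedTensorType>(input.getType());
  if (inputTy.getElementType() == outType.getElementType())
    return input; // Nothing to promote.

  auto castTy =
      RankedTensorType::get(inputTy.getShape(), outType.getElementType());
  return rewriter.create<tosa::CastOp>(input.getLoc(), castTy, input);
}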
@@ -74,11 +74,16 @@ public:
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TosaOpT>(
-        op,
-        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-            op.getType()),
-        adaptor.getSelf());
+    auto self = adaptor.getSelf();
+
+    auto outType = dyn_cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+
+    self = tosa::promoteType(rewriter, self, outType);
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);
+
     return success();
   }
 };
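For context on how this template is used: ConvertAtenUnaryOp is parameterized over a Torch op and its TOSA counterpart, and concrete unary ops are registered against it when the conversion patterns are populated. The sketch below shows that wiring under stated assumptions; the helper name, the INSERT_UNARY_PATTERN macro, and the specific op pairings are illustrative, and ConvertAtenUnaryOp itself is assumed to be in scope.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

// Illustrative sketch: register the templated unary pattern for a few ops so
// that e.g. aten.cos / aten.sin / aten.logical_not all funnel through the
// matchAndRewrite shown above (ConvertAtenUnaryOp is assumed to be in scope).
static void populateUnaryPatternsSketch(TypeConverter &typeConverter,
                                        RewritePatternSet &patterns,
                                        ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
  INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
  INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
  INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
#undef INSERT_UNARY_PATTERN
}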
@@ -6091,6 +6096,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
   auto builtinTensors =
       getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
 
+  for (auto &tensor : builtinTensors)
+    tensor = tosa::promoteType(rewriter, tensor, outType);
+
   auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
       rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
   rewriter.replaceOp(op, result.getResult());
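The aten.cat change applies the same promotion per input: each already-converted tensor is cast to the element type of the converted result before the tosa.concat is created, which is what allows mixed-dtype concatenations (for example the TensorsConcatPromoteDTypeModule_basic test referenced below) to lower. A condensed sketch of that flow, assuming tosa::promoteType is declared in TosaLegalizeUtils.h and building the concat directly rather than through tosa::CreateOpAndInfer:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

using namespace mlir;

// Condensed sketch of the promoted-concat flow: cast each input to the
// result's element type, then concatenate along `dim`. The real lowering
// goes through tosa::CreateOpAndInfer for shape inference; creating the op
// directly here is a simplification.
static Value buildPromotedConcatSketch(PatternRewriter &rewriter, Location loc,
                                       TensorType outType,
                                       SmallVector<Value> inputs, int64_t dim) {
  for (Value &input : inputs)
    input = tosa::promoteType(rewriter, input, outType);
  return rewriter.create<tosa::ConcatOp>(loc, outType, inputs,
                                         rewriter.getI32IntegerAttr(dim));
}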
@@ -1744,6 +1744,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseAtenLogicalNotOpPromoteModule_basic",
+    "ElementwiseCosIntModule_basic",
+    "ElementwiseReciprocalIntModule_basic",
+    "ElementwiseRsqrtIntModule_basic",
+    "ElementwiseSinIntModule_basic",
+    "FloatPowerTensorTensorStaticModule_basic",
     "AdaptiveMaxPool1dDimOneStatic_basic",
     "CollapseAllDimensionsModule_basic",
     "CollapseRank1DynamicModule_basic",
@@ -1786,7 +1792,6 @@ TOSA_PASS_SET = {
     "SliceCopy_Module_basic",
     "Threshold1dIntModule_basic",
     "Threshold2dIntModule_basic",
-    "Threshold3dIntModule_basic",
     "EmptyModule_contiguous",
     "EmptyModule_defaultDtype",
     "EmptyModule_falsePinMemory",
@@ -2435,6 +2440,7 @@ MAKE_FX_TOSA_PASS_SET = (
     TOSA_PASS_SET
     | {
         ### Tests additionally passing in make_fx_tosa
+        "IsInfiniteModule_basic",
         "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
         "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
         "ResNet18StaticModule_basic",
@@ -2510,6 +2516,8 @@ MAKE_FX_TOSA_PASS_SET = (
     }
 ) - {
     ### Test failing in make_fx_tosa but not in tosa
+    "AdaptiveMaxPool1dDimOneStatic_basic",
+    "FloatPowerTensorTensorStaticModule_basic",
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
     # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
@@ -3390,6 +3398,11 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "IsInfiniteModule_basic",
+    "LayerNormFwAndBwModule_basic",
+    "LayerNormManualFwAndBwModule_basic",
+    "SelfAttentionFwAndBwModule_basic",
+    "Threshold3dIntModule_basic",
     "ElementwiseCopysignModule_basic",
     "ElementwiseSignbitModule_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
@@ -3417,9 +3430,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
     "HstackBasicComplexModule_basic",
-    "HstackBasicFloatModule_basic",
-    "HstackBasicIntFloatModule_basic",
-    "HstackBasicIntModule_basic",
     "AtenIntMM_basic",
     "AtenKthvalueDynamicDimsModule_basic",
     "AtenKthvalueFloat64DynamicDimsModule_basic",
@@ -3597,8 +3607,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseAtanTensorIntModule_basic",
     "ElementwiseAtanhIntModule_basic",
     "ElementwiseAtanhModule_basic",
-    "ElementwiseAtenLogicalNotOpPromoteModule_basic",
-    "ElementwiseCosIntModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
@@ -3620,10 +3628,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
-    "ElementwiseReciprocalIntModule_basic",
-    "ElementwiseRsqrtIntModule_basic",
     "ElementwiseSigmoidIntModule_basic",
-    "ElementwiseSinIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
     "ElementwiseTanIntModule_basic",
@@ -3850,8 +3855,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "TensorToFloat_basic",
     "TensorToIntZeroRank_basic",
     "TensorToInt_basic",
-    "TensorsConcatPromoteDTypeModule_basic",
-    "TensorsStackPromoteDTypeModule_basic",
     "TestMultipleTensorAndPrimitiveTypesReturn_basic",
     "ThresholdBackward2dMixedModule_basic",
     "ToCopyWithDTypeFalsePinMemoryModule_basic",
@@ -3931,6 +3934,8 @@ ONNX_TOSA_CRASHING_SET = {
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "FloatPowerTensorTensorStaticModule_basic",
+    "IsInfiniteModule_basic",
     "ElementwiseCopysignModule_basic",
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",