Refactor TorchToTosa with separate construction of legal/illegal ops and conversion patterns.

pull/3759/head
Sayan Saha 2024-11-06 22:53:12 -05:00
parent 30c519369e
commit ffa472fb4b
2 changed files with 332 additions and 305 deletions

View File

@ -12,12 +12,25 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
namespace torch {
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
/// return the set of illegalOps
std::set<StringRef>
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

View File

@ -7215,11 +7215,31 @@ public:
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();
target.addIllegalDialect<Torch::TorchDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
populateTorchToTosaConversionLegalOps(target);
RewritePatternSet patterns(context);
auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
typeConverter, patterns);
for (auto op : illegalOps) {
target.addIllegalOp(OperationName(op, context));
}
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
// The following ops are never the primary reason why lowering fails.
// The backend contract only allows functions to return tensors thus there
// is always another op using them.
@ -7234,12 +7254,16 @@ public:
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addIllegalDialect<Torch::TorchDialect>();
}
RewritePatternSet patterns(context);
std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
std::set<StringRef> illegalOps;
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
@ -7247,7 +7271,7 @@ public:
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
@ -7261,21 +7285,20 @@ public:
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
tosa::ArithmeticRightShiftOp)
#undef INSERT_BINARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
@ -7284,7 +7307,7 @@ public:
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
@ -7305,14 +7328,14 @@ public:
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
#undef INSERT_BINARY_MUL_PATTERN
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
@ -7321,7 +7344,7 @@ public:
#undef INSERT_BINARY_DIV_PATTERN
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
@ -7330,7 +7353,7 @@ public:
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
@ -7342,7 +7365,7 @@ public:
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
@ -7354,78 +7377,73 @@ public:
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
#undef INSERT_SQUEEZE_OP_PATTERN
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATEMOP_PATTERN
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATEMOP_PATTERN
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
context);
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
target.addIllegalOp<AtenMaxPool2dOp>();
illegalOps.insert(AtenMaxPool2dOp::getOperationName());
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool1dOp>();
illegalOps.insert(AtenMaxPool1dOp::getOperationName());
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
illegalOps.insert(AtenAvgPool2dOp::getOperationName());
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp>();
illegalOps.insert(AtenAvgPool1dOp::getOperationName());
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context);
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
@ -7434,7 +7452,7 @@ public:
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
INSERT_FILL_PATTERN(AtenFillScalarOp);
@ -7442,14 +7460,14 @@ public:
#undef INSERT_FILL_PATTERN
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
#undef INSERT_MASKED_FILL_PATTERN
#define INSERT_POW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
@ -7457,7 +7475,7 @@ public:
#undef INSERT_POW_OP_PATTERN
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
@ -7466,7 +7484,7 @@ public:
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
@ -7524,17 +7542,13 @@ public:
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
#undef INSERT_CLONE_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
return illegalOps;
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {