mirror of https://github.com/llvm/torch-mlir
Refactor TorchToTosa with separate construction of legal/illegal ops and conversion patterns.
parent
30c519369e
commit
ffa472fb4b
|
@ -12,12 +12,25 @@
|
|||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
|
||||
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
|
||||
/// dialect.
|
||||
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
|
||||
|
||||
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
|
||||
/// return the set of illegalOps
|
||||
std::set<StringRef>
|
||||
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
|
||||
}
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
|
|
|
@ -7215,11 +7215,31 @@ public:
|
|||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
|
||||
arith::ArithDialect>();
|
||||
target.addIllegalDialect<Torch::TorchDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
populateTorchToTosaConversionLegalOps(target);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
|
||||
typeConverter, patterns);
|
||||
|
||||
for (auto op : illegalOps) {
|
||||
target.addIllegalOp(OperationName(op, context));
|
||||
}
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
|
||||
// The following ops are never the primary reason why lowering fails.
|
||||
// The backend contract only allows functions to return tensors thus there
|
||||
// is always another op using them.
|
||||
|
@ -7234,12 +7254,16 @@ public:
|
|||
target.addLegalOp<ConstantDeviceOp>();
|
||||
target.addLegalOp<PrimListConstructOp>();
|
||||
target.addLegalOp<PrimTupleConstructOp>();
|
||||
target.addIllegalDialect<Torch::TorchDialect>();
|
||||
}
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
|
||||
MLIRContext *context = patterns.getContext();
|
||||
std::set<StringRef> illegalOps;
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
|
||||
|
@ -7247,7 +7271,7 @@ public:
|
|||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
||||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||
|
@ -7261,21 +7285,20 @@ public:
|
|||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
|
||||
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
|
||||
tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
|
||||
tosa::ArithmeticRightShiftOp)
|
||||
#undef INSERT_BINARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
|
||||
|
@ -7284,7 +7307,7 @@ public:
|
|||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
||||
|
@ -7305,14 +7328,14 @@ public:
|
|||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
|
||||
#undef INSERT_BINARY_MUL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
|
@ -7321,7 +7344,7 @@ public:
|
|||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
|
||||
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
|
||||
|
@ -7330,7 +7353,7 @@ public:
|
|||
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
|
||||
|
||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
|
||||
|
@ -7342,7 +7365,7 @@ public:
|
|||
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
||||
|
@ -7354,78 +7377,73 @@ public:
|
|||
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
|
||||
mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
|
||||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
|
||||
mlir::tosa::convertReduceSumOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
|
||||
mlir::tosa::convertReduceMaxOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
|
||||
mlir::tosa::convertReduceMinOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
|
||||
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
|
||||
#undef INSERT_SQUEEZE_OP_PATTERN
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
|
||||
#undef INSERT_MATMUL_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
|
||||
#undef INSERT_MM_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
|
||||
#undef INSERT_LINEAR_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
|
||||
tosa::AvgPool2dOp);
|
||||
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
|
||||
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
illegalOps.insert(AtenMaxPool2dOp::getOperationName());
|
||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenMaxPool1dOp>();
|
||||
illegalOps.insert(AtenMaxPool1dOp::getOperationName());
|
||||
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
illegalOps.insert(AtenAvgPool2dOp::getOperationName());
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||
illegalOps.insert(AtenAvgPool1dOp::getOperationName());
|
||||
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
context);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
||||
|
@ -7434,7 +7452,7 @@ public:
|
|||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||
|
@ -7442,14 +7460,14 @@ public:
|
|||
#undef INSERT_FILL_PATTERN
|
||||
|
||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
|
||||
#undef INSERT_MASKED_FILL_PATTERN
|
||||
|
||||
#define INSERT_POW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
|
||||
|
@ -7457,7 +7475,7 @@ public:
|
|||
#undef INSERT_POW_OP_PATTERN
|
||||
|
||||
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
|
||||
|
@ -7466,7 +7484,7 @@ public:
|
|||
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
|
@ -7524,17 +7542,13 @@ public:
|
|||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
|
||||
#undef INSERT_CLONE_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return illegalOps;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToTosaPass() {
|
||||
|
|
Loading…
Reference in New Issue