Refactor TorchToTosa with separate construction of legal/illegal ops and conversion patterns.

pull/3759/head
Sayan Saha 2024-11-06 22:53:12 -05:00
parent 30c519369e
commit ffa472fb4b
2 changed files with 332 additions and 305 deletions

View File

@ -12,12 +12,25 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
namespace torch {
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
/// return the set of illegalOps
std::set<StringRef>
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

View File

@ -7215,319 +7215,22 @@ public:
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();
target.addIllegalDialect<Torch::TorchDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
// The following ops are never the primary reason why lowering fails.
// The backend contract only allows functions to return tensors thus there
// is always another op using them.
// When we have a chain of torch.constant.int followed by a unsupported
// torch op, we want the pass to mention the unsupported torch op
// in the error message.
target.addLegalOp<ConstantNoneOp>();
target.addLegalOp<ConstantBoolOp>();
target.addLegalOp<ConstantIntOp>();
target.addLegalOp<ConstantFloatOp>();
target.addLegalOp<ConstantStrOp>();
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addIllegalDialect<Torch::TorchDialect>();
populateTorchToTosaConversionLegalOps(target);
RewritePatternSet patterns(context);
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_FPONLY_PATTERN
auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
typeConverter, patterns);
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
tosa::ArithmeticRightShiftOp)
#undef INSERT_BINARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
#undef INSERT_BINARY_MUL_PATTERN
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
#undef INSERT_BINARY_DIV_PATTERN
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
mlir::tosa::convertReduceMeanOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
mlir::tosa::convertReduceSumOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp,
mlir::tosa::convertLinalgVectorNormOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
mlir::tosa::convertReduceAllOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
target.addIllegalOp<AtenOp>(); \
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
#undef INSERT_SQUEEZE_OP_PATTERN
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATEMOP_PATTERN
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATEMOP_PATTERN
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
context);
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool1dOp>();
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp>();
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context);
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
INSERT_FILL_PATTERN(AtenFillScalarOp);
INSERT_FILL_PATTERN(AtenFillTensorOp);
#undef INSERT_FILL_PATTERN
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
#undef INSERT_MASKED_FILL_PATTERN
#define INSERT_POW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
#undef INSERT_POW_OP_PATTERN
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReshapeOp);
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenDropoutOp);
INSERT_ATENOP_PATTERN(AtenViewOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp);
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRoundOp);
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
#undef INSERT_CLONE_ATENOP_PATTERN
for (auto op : illegalOps) {
target.addIllegalOp(OperationName(op, context));
}
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@ -7536,6 +7239,317 @@ public:
};
} // namespace
void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
// The following ops are never the primary reason why lowering fails.
// The backend contract only allows functions to return tensors thus there
// is always another op using them.
// When we have a chain of torch.constant.int followed by a unsupported
// torch op, we want the pass to mention the unsupported torch op
// in the error message.
target.addLegalOp<ConstantNoneOp>();
target.addLegalOp<ConstantBoolOp>();
target.addLegalOp<ConstantIntOp>();
target.addLegalOp<ConstantFloatOp>();
target.addLegalOp<ConstantStrOp>();
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
}
std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
std::set<StringRef> illegalOps;
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
tosa::ArithmeticRightShiftOp)
#undef INSERT_BINARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
#undef INSERT_BINARY_MUL_PATTERN
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
#undef INSERT_BINARY_DIV_PATTERN
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
mlir::tosa::convertReduceMeanOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
mlir::tosa::convertReduceSumOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp,
mlir::tosa::convertLinalgVectorNormOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
mlir::tosa::convertReduceAllOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
#undef INSERT_SQUEEZE_OP_PATTERN
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATEMOP_PATTERN
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATEMOP_PATTERN
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
context);
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
illegalOps.insert(AtenMaxPool2dOp::getOperationName());
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
illegalOps.insert(AtenMaxPool1dOp::getOperationName());
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
illegalOps.insert(AtenAvgPool2dOp::getOperationName());
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
illegalOps.insert(AtenAvgPool1dOp::getOperationName());
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context);
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
INSERT_FILL_PATTERN(AtenFillScalarOp);
INSERT_FILL_PATTERN(AtenFillTensorOp);
#undef INSERT_FILL_PATTERN
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
#undef INSERT_MASKED_FILL_PATTERN
#define INSERT_POW_OP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
#undef INSERT_POW_OP_PATTERN
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReshapeOp);
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenDropoutOp);
INSERT_ATENOP_PATTERN(AtenViewOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp);
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRoundOp);
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
#undef INSERT_CLONE_ATENOP_PATTERN
return illegalOps;
}
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {
return std::make_unique<ConvertTorchToTosa>();