mirror of https://github.com/llvm/torch-mlir
Refactor TorchToTosa with separate construction of legal/illegal ops and conversion patterns.
parent
30c519369e
commit
ffa472fb4b
|
@ -12,12 +12,25 @@
|
|||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
|
||||
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
|
||||
/// dialect.
|
||||
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
|
||||
|
||||
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
|
||||
/// return the set of illegalOps
|
||||
std::set<StringRef>
|
||||
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
|
||||
}
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
|
|
|
@ -7215,319 +7215,22 @@ public:
|
|||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
|
||||
arith::ArithDialect>();
|
||||
target.addIllegalDialect<Torch::TorchDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
// The following ops are never the primary reason why lowering fails.
|
||||
// The backend contract only allows functions to return tensors thus there
|
||||
// is always another op using them.
|
||||
// When we have a chain of torch.constant.int followed by a unsupported
|
||||
// torch op, we want the pass to mention the unsupported torch op
|
||||
// in the error message.
|
||||
target.addLegalOp<ConstantNoneOp>();
|
||||
target.addLegalOp<ConstantBoolOp>();
|
||||
target.addLegalOp<ConstantIntOp>();
|
||||
target.addLegalOp<ConstantFloatOp>();
|
||||
target.addLegalOp<ConstantStrOp>();
|
||||
target.addLegalOp<ConstantDeviceOp>();
|
||||
target.addLegalOp<PrimListConstructOp>();
|
||||
target.addLegalOp<PrimTupleConstructOp>();
|
||||
target.addIllegalDialect<Torch::TorchDialect>();
|
||||
populateTorchToTosaConversionLegalOps(target);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
|
||||
typeConverter, patterns);
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
||||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
|
||||
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
|
||||
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
|
||||
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
|
||||
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
|
||||
tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
|
||||
tosa::ArithmeticRightShiftOp)
|
||||
#undef INSERT_BINARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
|
||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
|
||||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
|
||||
#undef INSERT_BINARY_MUL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
|
||||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
|
||||
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
|
||||
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
|
||||
|
||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
|
||||
mlir::tosa::convertReduceMeanOp)
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
|
||||
mlir::tosa::convertReduceSumOp)
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp,
|
||||
mlir::tosa::convertLinalgVectorNormOp)
|
||||
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
||||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
|
||||
mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
|
||||
mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
|
||||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
|
||||
mlir::tosa::convertReduceSumOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
|
||||
mlir::tosa::convertReduceMaxOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
|
||||
mlir::tosa::convertReduceMinOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
|
||||
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
|
||||
#undef INSERT_SQUEEZE_OP_PATTERN
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
|
||||
#undef INSERT_MATMUL_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
|
||||
#undef INSERT_MM_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
|
||||
#undef INSERT_LINEAR_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
|
||||
tosa::AvgPool2dOp);
|
||||
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
|
||||
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenMaxPool1dOp>();
|
||||
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
context);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||
#undef INSERT_FILL_PATTERN
|
||||
|
||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
|
||||
#undef INSERT_MASKED_FILL_PATTERN
|
||||
|
||||
#define INSERT_POW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
|
||||
#undef INSERT_POW_OP_PATTERN
|
||||
|
||||
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
|
||||
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReshapeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDropoutOp);
|
||||
INSERT_ATENOP_PATTERN(AtenViewOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAbsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
|
||||
#undef INSERT_CLONE_ATENOP_PATTERN
|
||||
for (auto op : illegalOps) {
|
||||
target.addIllegalOp(OperationName(op, context));
|
||||
}
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
@ -7536,6 +7239,317 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
|
||||
// The following ops are never the primary reason why lowering fails.
|
||||
// The backend contract only allows functions to return tensors thus there
|
||||
// is always another op using them.
|
||||
// When we have a chain of torch.constant.int followed by a unsupported
|
||||
// torch op, we want the pass to mention the unsupported torch op
|
||||
// in the error message.
|
||||
target.addLegalOp<ConstantNoneOp>();
|
||||
target.addLegalOp<ConstantBoolOp>();
|
||||
target.addLegalOp<ConstantIntOp>();
|
||||
target.addLegalOp<ConstantFloatOp>();
|
||||
target.addLegalOp<ConstantStrOp>();
|
||||
target.addLegalOp<ConstantDeviceOp>();
|
||||
target.addLegalOp<PrimListConstructOp>();
|
||||
target.addLegalOp<PrimTupleConstructOp>();
|
||||
}
|
||||
|
||||
std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
|
||||
MLIRContext *context = patterns.getContext();
|
||||
std::set<StringRef> illegalOps;
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
||||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
|
||||
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
|
||||
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
|
||||
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
|
||||
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
|
||||
tosa::ArithmeticRightShiftOp)
|
||||
#undef INSERT_BINARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
|
||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
|
||||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
|
||||
#undef INSERT_BINARY_MUL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
|
||||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
|
||||
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
|
||||
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
|
||||
|
||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
|
||||
mlir::tosa::convertReduceMeanOp)
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
|
||||
mlir::tosa::convertReduceSumOp)
|
||||
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp,
|
||||
mlir::tosa::convertLinalgVectorNormOp)
|
||||
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
||||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
|
||||
mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
typeConverter, context);
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
|
||||
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
|
||||
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
|
||||
#undef INSERT_SQUEEZE_OP_PATTERN
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
|
||||
#undef INSERT_MATMUL_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
|
||||
#undef INSERT_MM_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
|
||||
#undef INSERT_LINEAR_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
|
||||
tosa::AvgPool2dOp);
|
||||
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
|
||||
|
||||
illegalOps.insert(AtenMaxPool2dOp::getOperationName());
|
||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
|
||||
illegalOps.insert(AtenMaxPool1dOp::getOperationName());
|
||||
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
|
||||
|
||||
illegalOps.insert(AtenAvgPool2dOp::getOperationName());
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
|
||||
illegalOps.insert(AtenAvgPool1dOp::getOperationName());
|
||||
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
context);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||
#undef INSERT_FILL_PATTERN
|
||||
|
||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
|
||||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
|
||||
#undef INSERT_MASKED_FILL_PATTERN
|
||||
|
||||
#define INSERT_POW_OP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
|
||||
#undef INSERT_POW_OP_PATTERN
|
||||
|
||||
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
|
||||
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReshapeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDropoutOp);
|
||||
INSERT_ATENOP_PATTERN(AtenViewOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAbsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
illegalOps.insert(AtenOp::getOperationName()); \
|
||||
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
|
||||
#undef INSERT_CLONE_ATENOP_PATTERN
|
||||
|
||||
return illegalOps;
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToTosaPass() {
|
||||
return std::make_unique<ConvertTorchToTosa>();
|
||||
|
|
Loading…
Reference in New Issue