[mlir][tosa] Refactor conversions to use templates (#416)

- Remove use of conversion construction macros
- Add mul and div op conversions
- Add corresponding tests

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Suraj Sudhir 2021-11-11 16:15:58 -08:00 committed by GitHub
parent 1019ddf5a0
commit 628a21bb13
3 changed files with 245 additions and 129 deletions
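
In miniature, the refactor swaps one macro-stamped pattern class per ATen op for a single class template instantiated per op pair (a condensed sketch of the diff that follows):

// Before: the macro expands to a complete pattern class for each op.
DEF_FULLCONV_FPONLY_UNARY_ATENOP(Log, Log) // ~> class ConvertAtenLogOp { ... };

// After: one template expresses the whole family once.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> { /*...*/ };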


@@ -28,4 +28,5 @@ TOSA_PASS_SET = {
"ElementwiseReluModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic",
"TanhBackward_basic",
}


@@ -23,116 +23,112 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
// These legalizations are for unary ops with only for FP datatypes.
// These legalizations are for unary ops that support only floating-point datatypes.
// There is no supported quantized integer mode for these.
#define DEF_FULLCONV_FPONLY_UNARY_ATENOP(aten_op, tosa_op) \
class ConvertAten##aten_op##Op \
: public OpConversionPattern<Aten##aten_op##Op> { \
public: \
using OpConversionPattern::OpConversionPattern; \
LogicalResult \
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
ConversionPatternRewriter &rewriter) const override { \
Aten##aten_op##Op::Adaptor adaptor(operands); \
Value self = adaptor.self(); \
auto selfTy = self.getType().cast<TensorType>(); \
if (selfTy) { \
if (selfTy.getElementType().isa<mlir::FloatType>()) { \
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
op, getTypeConverter()->convertType(op.getType()), self); \
return success(); \
} else { \
return op.emitError("Only FP type legalization supported"); \
} \
} else { \
return op.emitError("Only Tensor types supported in TOSA"); \
} \
} \
};
DEF_FULLCONV_FPONLY_UNARY_ATENOP(Log, Log)
DEF_FULLCONV_FPONLY_UNARY_ATENOP(Exp, Exp)
#undef DEF_FULLCONV_FPONLY_UNARY_ATENOP
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename AtenOpT::Adaptor adaptor(operands);
Value self = adaptor.self();
auto selfTy = self.getType().cast<TensorType>();
// These unary op legalizations are identical for FP or quantized types
#define DEF_FULLCONV_UNARY_ATENOP(aten_op, tosa_op) \
class ConvertAten##aten_op##Op \
: public OpConversionPattern<Aten##aten_op##Op> { \
public: \
using OpConversionPattern::OpConversionPattern; \
LogicalResult \
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
ConversionPatternRewriter &rewriter) const override { \
Aten##aten_op##Op::Adaptor adaptor(operands); \
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
op, getTypeConverter()->convertType(op.getType()), adaptor.self()); \
return success(); \
} \
};
DEF_FULLCONV_UNARY_ATENOP(Neg, Negate)
DEF_FULLCONV_UNARY_ATENOP(Floor, Floor)
DEF_FULLCONV_UNARY_ATENOP(BitwiseNot, BitwiseNot)
#undef DEF_FULLCONV_UNARY_ATENOP
if (!selfTy)
return op.emitError("Only Tensor types supported in TOSA");
// These binary op legalizations are identical for FP or quantized types
#define DEF_FULLCONV_ADDSUB_ATENOP(aten_op, tosa_op) \
class ConvertAten##aten_op##Op \
: public OpConversionPattern<Aten##aten_op##Op> { \
public: \
using OpConversionPattern::OpConversionPattern; \
LogicalResult matchAndRewrite(Aten##aten_op##Op op, \
ArrayRef<Value> operands, \
ConversionPatternRewriter &rewriter) const { \
Aten##aten_op##Op::Adaptor adaptor(operands); \
\
Value lhs = adaptor.self(); \
auto lhsTy = lhs.getType().cast<TensorType>(); \
Value rhs = adaptor.other(); \
auto rhsTy = rhs.getType().cast<TensorType>(); \
\
if (!lhsTy || !rhsTy) \
return op.emitError("Only Tensor types supported in TOSA"); \
\
auto lhsElemTy = lhsTy.getElementType(); \
auto rhsElemTy = rhsTy.getElementType(); \
\
if (lhsElemTy != rhsElemTy) \
return op.emitError("Add: input datatypes mismatched"); \
\
/* FIXME: Handle alpha. \
Needs extraction of floating point constant. */ \
\
if (lhsElemTy.isa<mlir::FloatType>()) { \
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
op, getTypeConverter()->convertType(op.getType()), lhs, rhs); \
return success(); \
} else { \
return op.emitError("Only FP type legalization supported"); \
} \
} \
};
DEF_FULLCONV_ADDSUB_ATENOP(AddTensor, Add)
DEF_FULLCONV_ADDSUB_ATENOP(SubTensor, Sub)
#undef DEF_FULLCONV_ADDSUB_ATENOP
if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
self);
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization supported");
}
}
};
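For reference, this template is instantiated at pattern-registration time in the pass (via the INSERT_UNARY_FPONLY_PATTERN macro further down), e.g.:

patterns.add<ConvertAtenUnaryFPOnlyOp<AtenLogOp, tosa::LogOp>>(typeConverter,
                                                               context);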
// These legalizations have both FP and quantized type supported modes.
// Their rewriters are expressed below
#define DECL_CONVERT_ATENOP(aten_op) \
class ConvertAten##aten_op##Op \
: public OpConversionPattern<Aten##aten_op##Op> { \
public: \
using OpConversionPattern::OpConversionPattern; \
LogicalResult \
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
ConversionPatternRewriter &rewriter) const override; \
};
DECL_CONVERT_ATENOP(Tanh)
DECL_CONVERT_ATENOP(Sigmoid)
DECL_CONVERT_ATENOP(Relu)
#undef DECL_CONVERT_ATENOP
// These unary op legalizations are identical for floating-point
// or quantized types
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename AtenOpT::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
};
LogicalResult
ConvertAtenTanhOp::matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// These binary op legalizations are specific to add/sub, which carry an
// alpha multiplier.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
LogicalResult matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
typename AtenOpT::Adaptor adaptor(operands);
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
// FIXME: Handle alpha.
// Needs extraction of floating point constant.
if (lhsElemTy.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhs);
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization supported");
}
}
};
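The FIXME above marks the one semantic gap: aten.add.Tensor computes self + alpha * other, so this pattern is only correct when alpha == 1. A hedged sketch of what alpha handling might look like, pre-scaling rhs before the add/sub; it assumes a constant integer alpha, the torch-mlir matcher m_TorchConstantInt, and glosses over TOSA rank-broadcast rules (not part of this commit):

// Hypothetical sketch, inserted before the replaceOpWithNewOp call above.
int64_t alpha;
if (!matchPattern(op.alpha(), m_TorchConstantInt(&alpha)))
  return op.emitError("Only constant integer alpha currently supported");
if (alpha != 1) {
  // Materialize alpha as a rank-0 tosa.const and pre-multiply rhs, so that
  // self + alpha * other becomes tosa.add(lhs, tosa.mul(rhs, alphaConst)).
  auto constTy = RankedTensorType::get({}, lhsElemTy);
  Value alphaConst = rewriter.create<tosa::ConstOp>(
      op->getLoc(), constTy,
      DenseElementsAttr::get(constTy, rewriter.getFloatAttr(lhsElemTy, alpha)));
  rhs = rewriter.create<tosa::MulOp>(op->getLoc(), rhsTy, rhs, alphaConst,
                                     /*shift=*/0);
}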
// This defines a pattern template for ops whose legalizations are
// specialized individually below.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AtenTanhOp::Adaptor adaptor(operands);
Value self = adaptor.self();
auto selfTy = self.getType().cast<TensorType>();
@@ -143,11 +139,13 @@ ConvertAtenTanhOp::matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
} else {
// Tanh legalization in TOSA for quantized element-type uses
// specialized tosa.table construct.
return op.emitError("Only FP type legalization currently supported");
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
}
LogicalResult ConvertAtenSigmoidOp::matchAndRewrite(
template <>
LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
AtenSigmoidOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AtenSigmoidOp::Adaptor adaptor(operands);
@@ -160,13 +158,15 @@ LogicalResult ConvertAtenSigmoidOp::matchAndRewrite(
} else {
// Sigmoid legalization in TOSA for quantized element-type uses
// specialized tosa.table construct.
return op.emitError("Only FP type legalization currently supported");
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
} // namespace
}
LogicalResult
ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
template <>
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
AtenReluOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AtenReluOp::Adaptor adaptor(operands);
Value self = adaptor.self();
auto selfTy = self.getType().cast<TensorType>();
@@ -177,7 +177,8 @@ ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
if (selfTy) {
// Rescale the clampIn for quantized types. TBD
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
return op.emitError("Only FP type legalization currently supported");
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, getTypeConverter()->convertType(op.getType()), clampIn,
@@ -191,6 +192,71 @@ ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
}
}
template <>
LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
AtenMulTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AtenMulTensorOp::Adaptor adaptor(operands);
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
if (lhsElemTy.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
/*shift=*/0);
return success();
} else {
// Quantized multiplication may need to rescale inputs.
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
}
template <>
LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
AtenDivTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AtenDivTensorOp::Adaptor adaptor(operands);
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
if (lhsElemTy.isa<mlir::FloatType>()) {
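// TOSA has no floating-point divide, so x / y is lowered as
// x * reciprocal(y); this is the tosa.reciprocal + tosa.mul pair
// the tests below check for.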
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), lhs,
rcpOp.getResult(), /*shift=*/0);
} else {
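// tosa.div is an integer divide, so integer element types map onto it
// directly.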
rewriter.replaceOpWithNewOp<tosa::DivOp>(
op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
}
return success();
}
} // namespace
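With the ConvertAtenOp template in place, supporting a new op with a bespoke lowering is a specialization plus a registration line. A sketch for a hypothetical rsqrt conversion, assuming an AtenRsqrtOp in the Torch dialect and following the shape of the specializations above (not part of this commit):

template <>
LogicalResult ConvertAtenOp<AtenRsqrtOp>::matchAndRewrite(
    AtenRsqrtOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  AtenRsqrtOp::Adaptor adaptor(operands);
  Value self = adaptor.self();
  auto selfTy = self.getType().cast<TensorType>();
  if (!selfTy)
    return op.emitError("Only Tensor types supported in TOSA");
  if (!selfTy.getElementType().isa<mlir::FloatType>())
    return op.emitError(
        "Only floating-point datatype legalization currently supported");
  rewriter.replaceOpWithNewOp<tosa::RsqrtOp>(
      op, getTypeConverter()->convertType(op.getType()), self);
  return success();
}
// ...and in the pass: INSERT_ATENOP_PATTERN(AtenRsqrtOp);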
// -----------------------------------------------------------------------------
@@ -216,20 +282,38 @@ public:
RewritePatternSet patterns(context);
#define INSERT_NEW_PATTERN(aten_op) \
target.addIllegalOp<Aten##aten_op##Op>(); \
patterns.add<ConvertAten##aten_op##Op>(typeConverter, context);
INSERT_NEW_PATTERN(Log);
INSERT_NEW_PATTERN(Exp);
INSERT_NEW_PATTERN(Neg);
INSERT_NEW_PATTERN(Floor);
INSERT_NEW_PATTERN(BitwiseNot);
INSERT_NEW_PATTERN(AddTensor);
INSERT_NEW_PATTERN(SubTensor);
INSERT_NEW_PATTERN(Tanh);
INSERT_NEW_PATTERN(Sigmoid);
INSERT_NEW_PATTERN(Relu);
#undef INSERT_NEW_PATTERN
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))


@@ -113,7 +113,7 @@ func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.add"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
@@ -130,9 +130,40 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sub"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.mul$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.div$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RCP:.*]] = "tosa.reciprocal"(%[[ARG1_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
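
For reference, these FileCheck tests are driven by the test file's RUN line, which (hedging on the exact flags, since the header is not part of this diff) looks roughly like:

// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s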