mirror of https://github.com/llvm/torch-mlir
[mlir][tosa] Refactor conversions to use templates (#416)
- Remove use of conversion construction macros - Add mul and div op conversions - Add corresponding tests Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>pull/419/head snapshot-20211112.79
parent
1019ddf5a0
commit
628a21bb13
|
@ -28,4 +28,5 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"TanhBackward_basic",
|
||||
}
|
||||
|
|
|
@ -23,116 +23,112 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
// These legalizations are for unary ops with only for FP datatypes.
|
||||
|
||||
// These legalizations are for unary ops with only for floating point datatypes.
|
||||
// There is no supported quantized integer mode for these.
|
||||
#define DEF_FULLCONV_FPONLY_UNARY_ATENOP(aten_op, tosa_op) \
|
||||
class ConvertAten##aten_op##Op \
|
||||
: public OpConversionPattern<Aten##aten_op##Op> { \
|
||||
public: \
|
||||
using OpConversionPattern::OpConversionPattern; \
|
||||
LogicalResult \
|
||||
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
|
||||
ConversionPatternRewriter &rewriter) const override { \
|
||||
Aten##aten_op##Op::Adaptor adaptor(operands); \
|
||||
Value self = adaptor.self(); \
|
||||
auto selfTy = self.getType().cast<TensorType>(); \
|
||||
if (selfTy) { \
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) { \
|
||||
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
|
||||
op, getTypeConverter()->convertType(op.getType()), self); \
|
||||
return success(); \
|
||||
} else { \
|
||||
return op.emitError("Only FP type legalization supported"); \
|
||||
} \
|
||||
} else { \
|
||||
return op.emitError("Only Tensor types supported in TOSA"); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
DEF_FULLCONV_FPONLY_UNARY_ATENOP(Log, Log)
|
||||
DEF_FULLCONV_FPONLY_UNARY_ATENOP(Exp, Exp)
|
||||
#undef DEF_FULLCONV_FPONLY_UNARY_ATENOP
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
typename AtenOpT::Adaptor adaptor(operands);
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
|
||||
// These unary op legalizations are identical for FP or quantized types
|
||||
#define DEF_FULLCONV_UNARY_ATENOP(aten_op, tosa_op) \
|
||||
class ConvertAten##aten_op##Op \
|
||||
: public OpConversionPattern<Aten##aten_op##Op> { \
|
||||
public: \
|
||||
using OpConversionPattern::OpConversionPattern; \
|
||||
LogicalResult \
|
||||
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
|
||||
ConversionPatternRewriter &rewriter) const override { \
|
||||
Aten##aten_op##Op::Adaptor adaptor(operands); \
|
||||
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self()); \
|
||||
return success(); \
|
||||
} \
|
||||
};
|
||||
DEF_FULLCONV_UNARY_ATENOP(Neg, Negate)
|
||||
DEF_FULLCONV_UNARY_ATENOP(Floor, Floor)
|
||||
DEF_FULLCONV_UNARY_ATENOP(BitwiseNot, BitwiseNot)
|
||||
#undef DEF_FULLCONV_UNARY_ATENOP
|
||||
if (!selfTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
// These binary op legalizations are identical for FP or quantized types
|
||||
#define DEF_FULLCONV_ADDSUB_ATENOP(aten_op, tosa_op) \
|
||||
class ConvertAten##aten_op##Op \
|
||||
: public OpConversionPattern<Aten##aten_op##Op> { \
|
||||
public: \
|
||||
using OpConversionPattern::OpConversionPattern; \
|
||||
LogicalResult matchAndRewrite(Aten##aten_op##Op op, \
|
||||
ArrayRef<Value> operands, \
|
||||
ConversionPatternRewriter &rewriter) const { \
|
||||
Aten##aten_op##Op::Adaptor adaptor(operands); \
|
||||
\
|
||||
Value lhs = adaptor.self(); \
|
||||
auto lhsTy = lhs.getType().cast<TensorType>(); \
|
||||
Value rhs = adaptor.other(); \
|
||||
auto rhsTy = rhs.getType().cast<TensorType>(); \
|
||||
\
|
||||
if (!lhsTy || !rhsTy) \
|
||||
return op.emitError("Only Tensor types supported in TOSA"); \
|
||||
\
|
||||
auto lhsElemTy = lhsTy.getElementType(); \
|
||||
auto rhsElemTy = rhsTy.getElementType(); \
|
||||
\
|
||||
if (lhsElemTy != rhsElemTy) \
|
||||
return op.emitError("Add: input datatypes mismatched"); \
|
||||
\
|
||||
/* FIXME: Handle alpha. \
|
||||
Needs extraction of floating point constant. */ \
|
||||
\
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) { \
|
||||
rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>( \
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs, rhs); \
|
||||
return success(); \
|
||||
} else { \
|
||||
return op.emitError("Only FP type legalization supported"); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
DEF_FULLCONV_ADDSUB_ATENOP(AddTensor, Add)
|
||||
DEF_FULLCONV_ADDSUB_ATENOP(SubTensor, Sub)
|
||||
#undef DEF_FULLCONV_ADDSUB_ATENOP
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
self);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization supported");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// These legalizations have both FP and quantized type supported modes.
|
||||
// Their rewriters are expressed below
|
||||
#define DECL_CONVERT_ATENOP(aten_op) \
|
||||
class ConvertAten##aten_op##Op \
|
||||
: public OpConversionPattern<Aten##aten_op##Op> { \
|
||||
public: \
|
||||
using OpConversionPattern::OpConversionPattern; \
|
||||
LogicalResult \
|
||||
matchAndRewrite(Aten##aten_op##Op op, ArrayRef<Value> operands, \
|
||||
ConversionPatternRewriter &rewriter) const override; \
|
||||
};
|
||||
DECL_CONVERT_ATENOP(Tanh)
|
||||
DECL_CONVERT_ATENOP(Sigmoid)
|
||||
DECL_CONVERT_ATENOP(Relu)
|
||||
#undef DECL_CONVERT_ATENOP
|
||||
// These unary op legalizations are identical for floating-point
|
||||
// or quantized types
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
typename AtenOpT::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.self());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
LogicalResult
|
||||
ConvertAtenTanhOp::matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// These binary op legalizations are specific to add/sub which have an
|
||||
// alpha multiplier.
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
typename AtenOpT::Adaptor adaptor(operands);
|
||||
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
|
||||
// FIXME: Handle alpha.
|
||||
// Needs extraction of floating point constant.
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhs);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization supported");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// This defines a template to construct ops whose legalizations are
|
||||
// specialized.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||
AtenTanhOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
AtenTanhOp::Adaptor adaptor(operands);
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
|
@ -143,11 +139,13 @@ ConvertAtenTanhOp::matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
|
|||
} else {
|
||||
// Sigmoid legalization in TOSA for quantized element-type uses
|
||||
// specialized tosa.table construct.
|
||||
return op.emitError("Only FP type legalization currently supported");
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult ConvertAtenSigmoidOp::matchAndRewrite(
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
|
||||
AtenSigmoidOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
AtenSigmoidOp::Adaptor adaptor(operands);
|
||||
|
@ -160,13 +158,15 @@ LogicalResult ConvertAtenSigmoidOp::matchAndRewrite(
|
|||
} else {
|
||||
// Sigmoid legalization in TOSA for quantized element-type uses
|
||||
// specialized tosa.table construct.
|
||||
return op.emitError("Only FP type legalization currently supported");
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
} // namespace
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||
AtenReluOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
AtenReluOp::Adaptor adaptor(operands);
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
|
@ -177,7 +177,8 @@ ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
|
|||
if (selfTy) {
|
||||
// Rescale the clampIn for quantized types. TBD
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
return op.emitError("Only FP type legalization currently supported");
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), clampIn,
|
||||
|
@ -191,6 +192,71 @@ ConvertAtenReluOp::matchAndRewrite(AtenReluOp op, ArrayRef<Value> operands,
|
|||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
|
||||
AtenMulTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
AtenMulTensorOp::Adaptor adaptor(operands);
|
||||
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
|
||||
/*shift=*/0);
|
||||
return success();
|
||||
} else {
|
||||
// Quantized multiplication may need to rescale inputs.
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
|
||||
AtenDivTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs,
|
||||
rcpOp.getResult(), /*shift=*/0);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<tosa::DivOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -216,20 +282,38 @@ public:
|
|||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
#define INSERT_NEW_PATTERN(aten_op) \
|
||||
target.addIllegalOp<Aten##aten_op##Op>(); \
|
||||
patterns.add<ConvertAten##aten_op##Op>(typeConverter, context);
|
||||
INSERT_NEW_PATTERN(Log);
|
||||
INSERT_NEW_PATTERN(Exp);
|
||||
INSERT_NEW_PATTERN(Neg);
|
||||
INSERT_NEW_PATTERN(Floor);
|
||||
INSERT_NEW_PATTERN(BitwiseNot);
|
||||
INSERT_NEW_PATTERN(AddTensor);
|
||||
INSERT_NEW_PATTERN(SubTensor);
|
||||
INSERT_NEW_PATTERN(Tanh);
|
||||
INSERT_NEW_PATTERN(Sigmoid);
|
||||
INSERT_NEW_PATTERN(Relu);
|
||||
#undef INSERT_NEW_PATTERN
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
||||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
|
||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -113,7 +113,7 @@ func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.add"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
|
||||
|
@ -130,9 +130,40 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
|
|||
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sub"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.mul$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.div$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RCP:.*]] = "tosa.reciprocal"(%[[ARG1_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue