From f5b6c4b601d5dce93162b8e7ab654f9a45346c7a Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 25 Apr 2022 17:36:41 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.div.float op

This commit adds the `aten.div.float` op definition, a constant folder for it,
and its lowering to `arith.divf`, together with canonicalization, conversion,
and e2e tests.

Signed-Off By: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 ++++++++++++++++
 lib/Conversion/TorchToStd/TorchToStd.cpp      |  3 ++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 21 +++++++++++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../torch_mlir_e2e_test/test_suite/scalar.py  | 18 +++++++++++
 test/Conversion/TorchToStd/basic.mlir         | 13 ++++++++
 test/Dialect/Torch/canonicalize.mlir          | 30 +++++++++++++++++++
 7 files changed, 111 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 045a3f560..7381e94c0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6420,6 +6420,31 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
   }];
 }
 
+def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::div.float : (float, float) -> (float)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_FloatType:$b
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDivFloatOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenDivFloatOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index a2275b9db..2aef720f9 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -214,6 +214,9 @@ public:
     target.addIllegalOp();
     patterns.add>(
         typeConverter, context);
+    target.addIllegalOp<AtenDivFloatOp>();
+    patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
+        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a56fd9521..21201e107 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -94,6 +94,10 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
   return IntegerAttr::get(IntegerType::get(context, 64), value);
 }
 
+static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
+  return FloatAttr::get(Float64Type::get(context), value);
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//
@@ -1515,6 +1519,23 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenDivFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
+  double lhs, rhs;
+  bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
+  bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
+  if (lConstant && lhs == 0.0)
+    return getF64FloatAttr(getContext(), 0.0);
+  if (lConstant && rConstant && rhs == 1.0)
+    return getF64FloatAttr(getContext(), lhs);
+  if (lConstant && rConstant)
+    return getF64FloatAttr(getContext(), lhs / rhs);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index b9436df62..799b4a0d5 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -488,6 +488,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::add.float_int : (float, int) -> (float)")
     emit("aten::sub.float : (float, float) -> (float)")
     emit("aten::mul.float : (float, float) -> (float)")
+    emit("aten::div.float : (float, float) -> (float)", has_folder=True)
     emit("aten::neg.float : (float) -> (float)")
     emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
     emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py
index de75b8751..38a9d7fa8 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -90,3 +90,21 @@ def MulIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
 
 # ==============================================================================
+
+class DivFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+        ([], torch.float64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return float(lhs) / float(rhs)
+
+
+@register_test_case(module_factory=lambda: DivFloatModule())
+def DivFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(()).double(), torch.rand(()).double())
diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir
index 31e715a11..09441d61b 100644
--- a/test/Conversion/TorchToStd/basic.mlir
+++ b/test/Conversion/TorchToStd/basic.mlir
@@ -153,3 +153,16 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
   %0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
   return %0 : !torch.int
 }
+
+// CHECK-LABEL: func @torch.aten.div.float(
+// CHECK-SAME:    %[[LHS:.*]]: !torch.float,
+// CHECK-SAME:    %[[RHS:.*]]: !torch.float) -> !torch.float {
+// CHECK:         %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK:         %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK:         %[[DIV:.*]] = arith.divf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK:         %[[OUT:.*]] = torch_c.from_f64 %[[DIV]]
+// CHECK:         return %[[OUT]] : !torch.float
+func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
+  %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index ef0d45e35..263603d68 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1132,3 +1132,33 @@ func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32
   %1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
   return %1 : !torch.tensor<[?],f32>
 }
+
+// CHECK-LABEL: func @torch.aten.div.float$fold_zero_dividend(
+// CHECK:         %[[CST0:.*]] = torch.constant.float 0.000000e+00
+// CHECK:         return %[[CST0]] : !torch.float
+func @torch.aten.div.float$fold_zero_dividend() -> !torch.float {
+  %float0 = torch.constant.float 0.0
+  %float5 = torch.constant.float 5.0
+  %0 = torch.aten.div.float %float0, %float5 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
+// CHECK-LABEL: func @torch.aten.div.float$fold_one_divisor(
+// CHECK:         %[[CST4:.*]] = torch.constant.float 4.000000e+00
+// CHECK:         return %[[CST4]] : !torch.float
+func @torch.aten.div.float$fold_one_divisor() -> !torch.float {
+  %float4 = torch.constant.float 4.0
+  %float1 = torch.constant.float 1.0
+  %0 = torch.aten.div.float %float4, %float1 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
+// CHECK-LABEL: func @torch.aten.div.float$fold_cst_operands(
+// CHECK:         %[[CST2:.*]] = torch.constant.float 2.000000e+00
+// CHECK:         return %[[CST2]] : !torch.float
+func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
+  %float4 = torch.constant.float 4.0
+  %float2 = torch.constant.float 2.0
+  %0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
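
Note on the lowering registration above: `ConvertAtenBinaryOp` is a helper that
already exists in lib/Conversion/TorchToStd/TorchToStd.cpp; this patch only
instantiates it for the new op. A rough sketch of what that pre-existing pattern
does is shown below (paraphrased for context, not the verbatim upstream code;
the old-style `adaptor.a()`/`adaptor.b()` accessors are assumed here):

    // Rewrites an aten binary scalar op into the corresponding arith op on the
    // type-converted operands (the torch_c.to_f64/from_f64 casts seen in the
    // FileCheck test come from the type converter's materializations).
    template <typename AtenOp, typename BinOp>
    class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
    public:
      using OpConversionPattern<AtenOp>::OpConversionPattern;
      LogicalResult
      matchAndRewrite(AtenOp op,
                      typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        Value lhs = adaptor.a();
        Value rhs = adaptor.b();
        // For AtenDivFloatOp + arith::DivFOp this produces `arith.divf lhs, rhs`.
        rewriter.replaceOpWithNewOp<BinOp>(op, lhs, rhs);
        return success();
      }
    };

Registering ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp> and marking
AtenDivFloatOp illegal is therefore all that is needed for the
`torch.aten.div.float` -> `arith.divf` lowering checked in
test/Conversion/TorchToStd/basic.mlir.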