[MLIR][TORCH] Add E2E support for aten.div.float op

This commit adds lowering of `aten.div.float` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/802/head
Vivek Khandelwal 2022-04-25 17:36:41 +05:30
parent 73cc2ac152
commit f5b6c4b601
7 changed files with 111 additions and 0 deletions

View File

@ -6420,6 +6420,31 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
}];
}
def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::div.float : (float, float) -> (float)`";
let arguments = (ins
Torch_FloatType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDivFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenDivFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -214,6 +214,9 @@ public:
target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
target.addIllegalOp<AtenDivFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@ -94,6 +94,10 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
return IntegerAttr::get(IntegerType::get(context, 64), value);
}
static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
return FloatAttr::get(Float64Type::get(context), value);
}
//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
@ -1515,6 +1519,23 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenDivFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
double lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
if (lConstant && lhs == 0.0)
return getF64FloatAttr(getContext(), 0.0);
if (lConstant && rConstant && rhs == 1.0)
return getF64FloatAttr(getContext(), lhs);
if (lConstant && rConstant)
return getF64FloatAttr(getContext(), lhs / rhs);
return nullptr;
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//

View File

@ -488,6 +488,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::add.float_int : (float, int) -> (float)")
emit("aten::sub.float : (float, float) -> (float)")
emit("aten::mul.float : (float, float) -> (float)")
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
emit("aten::neg.float : (float) -> (float)")
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)

View File

@ -90,3 +90,21 @@ def MulIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
# ==============================================================================
class DivFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float64, True),
([], torch.float64, True),
])
def forward(self, lhs, rhs):
return float(lhs)/float(rhs)
@register_test_case(module_factory=lambda: DivFloatModule())
def DivFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(()).double(), torch.rand(()).double())

View File

@ -153,3 +153,16 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.div.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
// CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
// CHECK: return %[[OUT:.*]] : !torch.float
func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}

View File

@ -1132,3 +1132,33 @@ func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
return %1 : !torch.tensor<[?],f32>
}
// CHECK-LABEL: func @torch.aten.div.float$fold_zero_dividend(
// CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00
// CHECK: return %[[CST0]] : !torch.float
func @torch.aten.div.float$fold_zero_dividend() -> !torch.float {
%float0 = torch.constant.float 0.0
%float5 = torch.constant.float 5.0
%0 = torch.aten.div.float %float0, %float5 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.aten.div.float$fold_one_divisor(
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
// CHECK: return %[[CST4]] : !torch.float
func @torch.aten.div.float$fold_one_divisor() -> !torch.float {
%float4 = torch.constant.float 4.0
%float1 = torch.constant.float 1.0
%0 = torch.aten.div.float %float4, %float1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.aten.div.float$fold_cst_operands(
// CHECK: %[[CST2:.*]] = torch.constant.float 2.000000e+00
// CHECK: return %[[CST2]] : !torch.float
func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
%float4 = torch.constant.float 4.0
%float2 = torch.constant.float 2.0
%0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}