mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.div.float op
This commit adds lowering of `aten.div.float` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/802/head
parent
73cc2ac152
commit
f5b6c4b601
|
@ -6420,6 +6420,31 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::div.float : (float, float) -> (float)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_FloatType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDivFloatOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenDivFloatOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -214,6 +214,9 @@ public:
|
|||
target.addIllegalOp<AtenSubFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenDivFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
||||
typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -94,6 +94,10 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
|
|||
return IntegerAttr::get(IntegerType::get(context, 64), value);
|
||||
}
|
||||
|
||||
static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
||||
return FloatAttr::get(Float64Type::get(context), value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MethodOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1515,6 +1519,23 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDivFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
double lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
|
||||
if (lConstant && lhs == 0.0)
|
||||
return getF64FloatAttr(getContext(), 0.0);
|
||||
if (lConstant && rConstant && rhs == 1.0)
|
||||
return getF64FloatAttr(getContext(), lhs);
|
||||
if (lConstant && rConstant)
|
||||
return getF64FloatAttr(getContext(), lhs / rhs);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -488,6 +488,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::add.float_int : (float, int) -> (float)")
|
||||
emit("aten::sub.float : (float, float) -> (float)")
|
||||
emit("aten::mul.float : (float, float) -> (float)")
|
||||
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
|
||||
emit("aten::neg.float : (float) -> (float)")
|
||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -90,3 +90,21 @@ def MulIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DivFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs)/float(rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DivFloatModule())
|
||||
def DivFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||
|
|
|
@ -153,3 +153,16 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
|
|||
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.div.float(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
|
||||
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
||||
// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
|
||||
// CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
|
||||
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
|
||||
// CHECK: return %[[OUT:.*]] : !torch.float
|
||||
func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
|
||||
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
|
|
@ -1132,3 +1132,33 @@ func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32
|
|||
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
|
||||
return %1 : !torch.tensor<[?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.div.float$fold_zero_dividend(
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: return %[[CST0]] : !torch.float
|
||||
func @torch.aten.div.float$fold_zero_dividend() -> !torch.float {
|
||||
%float0 = torch.constant.float 0.0
|
||||
%float5 = torch.constant.float 5.0
|
||||
%0 = torch.aten.div.float %float0, %float5 : !torch.float, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.div.float$fold_one_divisor(
|
||||
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
|
||||
// CHECK: return %[[CST4]] : !torch.float
|
||||
func @torch.aten.div.float$fold_one_divisor() -> !torch.float {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float1 = torch.constant.float 1.0
|
||||
%0 = torch.aten.div.float %float4, %float1 : !torch.float, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.div.float$fold_cst_operands(
|
||||
// CHECK: %[[CST2:.*]] = torch.constant.float 2.000000e+00
|
||||
// CHECK: return %[[CST2]] : !torch.float
|
||||
func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float2 = torch.constant.float 2.0
|
||||
%0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue