mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.div.float op
This commit adds lowering of the `aten.div.float` op.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
(branch: pull/802/head)
parent
73cc2ac152
commit
f5b6c4b601
|
@ -6420,6 +6420,31 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Op definition for `aten::div.float : (float, float) -> (float)`.
// NOTE(review): this section of GeneratedTorchOps.td is emitted by the
// torch_ods_gen.py emitter (see the matching `emit("aten::div.float ...")`
// line); keep it in sync with the generator rather than hand-editing.
def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::div.float : (float, float) -> (float)`";
  let arguments = (ins
    Torch_FloatType:$a,
    Torch_FloatType:$b
  );
  let results = (outs
    Torch_FloatType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenDivFloatOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenDivFloatOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  // Enables the AtenDivFloatOp::fold constant-folding hook.
  let hasFolder = 1;
}
|
||||||
|
|
||||||
def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
|
def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -214,6 +214,9 @@ public:
|
||||||
target.addIllegalOp<AtenSubFloatOp>();
|
target.addIllegalOp<AtenSubFloatOp>();
|
||||||
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
|
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenDivFloatOp>();
|
||||||
|
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
||||||
|
typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -94,6 +94,10 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
|
||||||
return IntegerAttr::get(IntegerType::get(context, 64), value);
|
return IntegerAttr::get(IntegerType::get(context, 64), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds an f64 FloatAttr holding `value`; shared by the constant-folding
/// hooks below (float-valued torch ops fold to f64 attributes).
static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
  return FloatAttr::get(Float64Type::get(context), value);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MethodOp
|
// MethodOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1515,6 +1519,23 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenDivFloatOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Folds `torch.aten.div.float` when both operands are constant floats.
///
/// NOTE(review): the original also folded `0.0 / x` to `0.0` whenever only
/// the dividend was a constant zero. That is unsound for a non-constant
/// divisor: x == 0.0 yields NaN, x < 0 yields -0.0, and x == NaN yields NaN,
/// so the fold could change the program's result. Folding is therefore
/// restricted to the fully-constant case, where `lhs / rhs` computes the
/// exact IEEE-754 result (this also subsumes the old `rhs == 1.0` and
/// zero-dividend special cases for constant operands).
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
  double lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
  if (lConstant && rConstant)
    return getF64FloatAttr(getContext(), lhs / rhs);
  return nullptr;
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -488,6 +488,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::add.float_int : (float, int) -> (float)")
|
emit("aten::add.float_int : (float, int) -> (float)")
|
||||||
emit("aten::sub.float : (float, float) -> (float)")
|
emit("aten::sub.float : (float, float) -> (float)")
|
||||||
emit("aten::mul.float : (float, float) -> (float)")
|
emit("aten::mul.float : (float, float) -> (float)")
|
||||||
|
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
|
||||||
emit("aten::neg.float : (float) -> (float)")
|
emit("aten::neg.float : (float) -> (float)")
|
||||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||||
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
||||||
|
|
|
@ -90,3 +90,21 @@ def MulIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class DivFloatModule(torch.nn.Module):
    # E2E test module exercising `aten::div.float`: dividing two Python
    # floats extracted from 0-d float64 tensors.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.float64, True),
        ([], torch.float64, True),
    ])
    def forward(self, lhs, rhs):
        # float() extracts the Python scalar from each 0-d tensor, so the
        # traced graph contains aten::div.float (scalar division) rather
        # than tensor division.
        return float(lhs)/float(rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: DivFloatModule())
def DivFloatModule_basic(module, tu: TestUtils):
    # 0-d double tensors; torch.rand samples [0, 1), so the divisor can be
    # arbitrarily small but is never exactly zero.
    module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||||
|
|
|
@ -153,3 +153,16 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
|
||||||
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
|
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
|
||||||
return %0 : !torch.int
|
return %0 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks the TorchToStd lowering of aten.div.float to arith.divf.
// NOTE(review): the scraped original was missing the `%` before
// `[[RHS_F64]]` in the arith.divf line (so the capture never matched the
// SSA value), named the division result `SUB`, and redundantly redefined
// captures (`[[X:.*]]`) at their use sites.
// CHECK-LABEL:   func @torch.aten.div.float(
// CHECK-SAME:                %[[LHS:.*]]: !torch.float,
// CHECK-SAME:                %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK:           %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
// CHECK:           %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
// CHECK:           %[[DIV:.*]] = arith.divf %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK:           %[[OUT:.*]] = torch_c.from_f64 %[[DIV]]
// CHECK:           return %[[OUT]] : !torch.float
func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
  %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}
|
||||||
|
|
|
@ -1132,3 +1132,33 @@ func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32
|
||||||
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
|
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
|
||||||
return %1 : !torch.tensor<[?],f32>
|
return %1 : !torch.tensor<[?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Folding test: constant 0.0 dividend with constant divisor folds to 0.0.
// CHECK-LABEL:   func @torch.aten.div.float$fold_zero_dividend(
// CHECK:           %[[CST0:.*]] = torch.constant.float 0.000000e+00
// CHECK:           return %[[CST0]] : !torch.float
func @torch.aten.div.float$fold_zero_dividend() -> !torch.float {
  %float0 = torch.constant.float 0.0
  %float5 = torch.constant.float 5.0
  %0 = torch.aten.div.float %float0, %float5 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}
|
||||||
|
|
||||||
|
// Folding test: constant division by 1.0 folds to the dividend.
// CHECK-LABEL:   func @torch.aten.div.float$fold_one_divisor(
// CHECK:           %[[CST4:.*]] = torch.constant.float 4.000000e+00
// CHECK:           return %[[CST4]] : !torch.float
func @torch.aten.div.float$fold_one_divisor() -> !torch.float {
  %float4 = torch.constant.float 4.0
  %float1 = torch.constant.float 1.0
  %0 = torch.aten.div.float %float4, %float1 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}
|
||||||
|
|
||||||
|
// Folding test: two constant operands fold to their quotient (4.0 / 2.0).
// CHECK-LABEL:   func @torch.aten.div.float$fold_cst_operands(
// CHECK:           %[[CST2:.*]] = torch.constant.float 2.000000e+00
// CHECK:           return %[[CST2]] : !torch.float
func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
  %float4 = torch.constant.float 4.0
  %float2 = torch.constant.float 2.0
  %0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}
|
||||||
|
|
Loading…
Reference in New Issue