mirror of https://github.com/llvm/torch-mlir
Revert "[MLIR][TORCH] Add E2E support for aten.ne.float_int op"
This reverts commit 51dd462592
.
pull/808/head
parent
5ef9f501fa
commit
7669ee4e4a
|
@ -6667,30 +6667,6 @@ def Torch_AtenGeFloatIntOp : Torch_Op<"aten.ge.float_int", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenNeFloatIntOp : Torch_Op<"aten.ne.float_int", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::ne.float_int : (float, int) -> (bool)`";
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_FloatType:$a,
|
|
||||||
Torch_IntType:$b
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_BoolType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenNeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
|
||||||
}
|
|
||||||
void AtenNeFloatIntOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -211,16 +211,13 @@ public:
|
||||||
patterns.add<
|
patterns.add<
|
||||||
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp>();
|
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp>();
|
||||||
patterns.add<
|
patterns.add<
|
||||||
ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
|
ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
patterns.add<ConvertAtenFloatComparisonOp<AtenGeFloatIntOp,
|
patterns.add<ConvertAtenFloatComparisonOp<AtenGeFloatIntOp,
|
||||||
arith::CmpFPredicate::UGE>>(
|
arith::CmpFPredicate::UGE>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
patterns.add<ConvertAtenFloatComparisonOp<AtenNeFloatIntOp,
|
|
||||||
arith::CmpFPredicate::UNE>>(
|
|
||||||
typeConverter, context);
|
|
||||||
target.addIllegalOp<ValueTensorLiteralOp>();
|
target.addIllegalOp<ValueTensorLiteralOp>();
|
||||||
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
|
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
|
||||||
|
|
||||||
|
|
|
@ -498,7 +498,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
|
||||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||||
emit("aten::ge.float_int : (float, int) -> (bool)")
|
emit("aten::ge.float_int : (float, int) -> (bool)")
|
||||||
emit("aten::ne.float_int : (float, int) -> (bool)")
|
|
||||||
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
||||||
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
|
|
|
@ -108,23 +108,3 @@ class GeFloatIntModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: GeFloatIntModule())
|
@register_test_case(module_factory=lambda: GeFloatIntModule())
|
||||||
def GeFloatIntModule_basic(module, tu: TestUtils):
|
def GeFloatIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(()), torch.randint(-100, 100, ()))
|
module.forward(torch.randn(()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class NeFloatIntModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([], torch.float64, True),
|
|
||||||
([], torch.int64, True),
|
|
||||||
])
|
|
||||||
def forward(self, lhs, rhs):
|
|
||||||
return float(lhs) != int(rhs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NeFloatIntModule())
|
|
||||||
def NeFloatIntModule_basic(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randn(()), torch.randint(-100, 100, ()))
|
|
||||||
|
|
|
@ -193,17 +193,3 @@ func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.
|
||||||
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.ne.float_int(
|
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
|
||||||
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
|
||||||
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
|
||||||
// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
|
|
||||||
// CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64
|
|
||||||
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
|
||||||
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
|
||||||
func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
|
|
||||||
%0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
|
||||||
return %0 : !torch.bool
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue