mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support aten.isinf (#2544)
Also fix linalg lowering from `UEQ` to `OEQ`. I will check other comparison's lowering later.pull/2554/head
parent
88adf384cc
commit
0378da0abd
|
@ -8024,6 +8024,29 @@ def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::isinf : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenIsinfOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenIsinfOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAllOp : Torch_Op<"aten.all", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -89,7 +89,7 @@ static Value createLessThanOrEqual(OpBuilder &b, Location loc,
|
|||
|
||||
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
|
||||
return createComparisonTemplate<arith::CmpFPredicate::OEQ,
|
||||
arith::CmpIPredicate::eq,
|
||||
arith::CmpIPredicate::eq>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
|
|
|
@ -6507,6 +6507,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.isinf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -9186,6 +9190,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.isinf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
|
|
|
@ -530,6 +530,26 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenIsinfOp : public OpRewritePattern<AtenIsinfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIsinfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
|
||||
mlir::FloatType f64Type = rewriter.getF64Type();
|
||||
Value inf = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getFloatAttr(
|
||||
f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
|
||||
Value abs = rewriter.create<AtenAbsOp>(loc, self.getType(), self);
|
||||
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), abs, inf);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
||||
public:
|
||||
|
@ -5458,6 +5478,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||
|
|
|
@ -426,6 +426,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenEyeOp>();
|
||||
target.addIllegalOp<AtenEyeMOp>();
|
||||
target.addIllegalOp<AtenIsnanOp>();
|
||||
target.addIllegalOp<AtenIsinfOp>();
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
target.addIllegalOp<AtenRelu6Op>();
|
||||
|
|
|
@ -1066,6 +1066,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
"ElementwiseIsnanModule_basic",
|
||||
"ElementwiseIsinfModule_basic",
|
||||
"TypePromotionAlphaWiderModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||
|
@ -1460,4 +1461,5 @@ LTC_XFAIL_SET = {
|
|||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseIsinfModule_basic",
|
||||
}
|
||||
|
|
|
@ -243,6 +243,9 @@ def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: Li
|
|||
def aten〇isnan〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇isinf〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
|
@ -2236,6 +2239,10 @@ def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.bool
|
||||
|
|
|
@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
|
||||
emit("aten::isnan : (Tensor) -> (Tensor)")
|
||||
emit("aten::isinf : (Tensor) -> (Tensor)")
|
||||
emit("aten::all : (Tensor) -> (Tensor)")
|
||||
emit("aten::all.bool : (bool[]) -> (bool)")
|
||||
emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
|
||||
|
|
|
@ -454,7 +454,7 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
||||
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -534,7 +534,7 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
|
||||
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -575,7 +575,7 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
|
||||
def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -765,7 +765,7 @@ class ElementwiseIsnanModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.isnan(x)
|
||||
|
@ -773,5 +773,25 @@ class ElementwiseIsnanModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseIsnanModule())
|
||||
def ElementwiseIsnanModule_basic(module, tu: TestUtils):
|
||||
x = torch.full((1, 1, 32), torch.nan)
|
||||
x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
|
||||
module.forward(x)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseIsinfModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.isinf(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseIsinfModule())
|
||||
def ElementwiseIsinfModule_basic(module, tu: TestUtils):
|
||||
x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
|
||||
module.forward(x)
|
||||
|
|
Loading…
Reference in New Issue