diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index cce4aa8a6..d268a31dd 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -18,6 +18,8 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "UnflattenStaticModule_basic",
+    "IscloseStaticModule_basic",
+    "IscloseStaticModuleTrue_basic",
 }
 
 TORCHDYNAMO_XFAIL_SET = {
@@ -928,6 +930,8 @@ STABLEHLO_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "IscloseStaticModule_basic",
+    "IscloseStaticModuleTrue_basic",
     "TileBigDimsSizeModule_basic",
     "TileSmallDimsSizeModule_basic",
     "IndexPutImpl2DNoneIndexStaticModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4f4fa561f..0530e3082 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [
   }];
 }
 
+def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other,
+    Torch_FloatType:$rtol,
+    Torch_FloatType:$atol,
+    Torch_BoolType:$equal_nan
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenIscloseOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index d2adefc4d..970ef15d8 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
+template <>
+LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
+    AtenIscloseOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Check args: rtol, atol and equal_nan must be compile-time constants, and
+  // equal_nan handling is not implemented.
+  double rtol, atol;
+  bool equalNan;
+  if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol)))
+    return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant");
+  if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol)))
+    return rewriter.notifyMatchFailure(op, "atol must be a scalar constant");
+  if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan)) || equalNan)
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: equal_nan is expected to be false");
+
+  // Check tensor type.
+  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
+  if (!selfType || !otherType)
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types input are currently supported");
+  if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+  if (!selfType.getElementType().isa<mlir::FloatType>() ||
+      !otherType.getElementType().isa<mlir::FloatType>()) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: only FP element type is supported");
+  }
+
+  // Lower to: |self - other| <= atol + rtol * |other|.
+  auto rhsSubOp = rewriter.create<tosa::SubOp>(
+      op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther());
+  auto rhsAbsOp =
+      rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);
+
+  auto lhsAbsOp =
+      rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
+  auto rtolConstOp =
+      tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
+  auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
+                                            rtolConstOp, lhsAbsOp, /*shift=*/0);
+  auto atolConstOp =
+      tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
+  auto addOp =
+      rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);
+
+  auto outType = getTypeConverter()->convertType(op.getType());
+  rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(op, outType, addOp,
+                                                    rhsAbsOp);
+
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
     AtenClampOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -5134,6 +5187,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
     INSERT_ATENOP_PATTERN(AtenCatOp);
     INSERT_ATENOP_PATTERN(AtenSqrtOp);
+    INSERT_ATENOP_PATTERN(AtenIscloseOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 513d7b018..47d76219f 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9093,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 00e752f01..958df70d5 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -844,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
 def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]:
     return upstream_shape_functions.unsqueeze(self, dim)
 
@@ -2171,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
 def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
+def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int:
+    return torch.bool
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
 def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
     _, query_dtype = query_rank_dtype
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 3916f3136..e473603ff 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::imag : (Tensor) -> (Tensor)")
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")
     emit("aten::view_as_real : (Tensor) -> (Tensor)")
+    emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
 
     # Ops with dynamic number of outputs
     emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index e0269e68c..d78253a58 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4580,3 +4580,48 @@ class Add_Module(torch.nn.Module):
 @register_test_case(module_factory=lambda: Add_Module())
 def Add_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3))
+
+
+# ==============================================================================
+
+
+class IscloseStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 5], torch.float32, True),
+        ([5, 5], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.isclose(x, y)
+
+
+@register_test_case(module_factory=lambda: IscloseStaticModule())
+def IscloseStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5), tu.rand(5, 5))
+
+
+# ==============================================================================
+
+
+class IscloseStaticModuleTrue(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('tensor', torch.ones(1))
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.isclose(x, self.tensor)
+
+@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
+def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
+    module.forward(torch.ones(5, 5))
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index dc4e4793a..46023598c 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
   %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
   return %0 : !torch.vtensor<[2, 4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @forward(
+// CHECK-SAME:     %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
+// CHECK-SAME:     %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
+// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
+// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
+// CHECK: }
+func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
+  %float1.000000e-08 = torch.constant.float 1.000000e-08
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %false = torch.constant.bool false
+  %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1>
+  return %0 : !torch.vtensor<[5,5],i1>
+}
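
For reference, the TOSA lowering above implements the same elementwise predicate that torch.isclose documents for the equal_nan == False case: |self - other| <= atol + rtol * |other|. Below is a minimal Python sketch of that check (the helper name isclose_ref is hypothetical and not part of this patch), mirroring the sub / abs / mul / add / greater_equal sequence emitted for TOSA:

import torch

def isclose_ref(self_t: torch.Tensor, other_t: torch.Tensor,
                rtol: float = 1e-05, atol: float = 1e-08) -> torch.Tensor:
    # Same ordering as the lowering: atol + rtol * |other| on the left,
    # |self - other| on the right, compared with >=.
    return (atol + rtol * other_t.abs()) >= (self_t - other_t).abs()

# Quick sanity check on the shapes used by the new e2e tests.
x, y = torch.rand(5, 5), torch.rand(5, 5)
assert torch.equal(isclose_ref(x, y), torch.isclose(x, y))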