diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 62164ceec..b52f0d47a 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -44,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = { # RuntimeError: Failed running call_function aten.uniform(... # https://github.com/pytorch/torchdynamo/issues/1954 "UniformNoCorrelationModule_basic", - # TypeError: expected np.ndarray (got float) - # TODO: This is due to returning a scalar float as output from the test. - # We should probably just standardize all tests to return tensors. - "DivIntModule_basic", #### Torch-MLIR internal compiler errors diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 30485ce47..fada9415a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9816,6 +9816,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 90cd0d3e9..426bb750b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2168,6 +2168,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenSubFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) { + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), [](double a, double b) { return a - b; }); +} + //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 5ea7dfbcd..071eb3fa4 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -607,7 +607,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)") - emit("aten::sub.float : (float, float) -> (float)") + emit("aten::sub.float : (float, float) -> (float)", has_folder=True) emit("aten::mul.float : (float, float) -> (float)") emit("aten::div.float : (float, float) -> (float)", has_folder=True) emit("aten::neg.float : (float) -> (float)") diff --git a/python/torch_mlir_e2e_test/configs/torchdynamo.py b/python/torch_mlir_e2e_test/configs/torchdynamo.py index ee1e35ec7..2b16b1b92 100644 --- a/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -5,6 +5,7 @@ from typing import List +import numpy import torch import torch._dynamo as dynamo import torch_mlir @@ -57,13 +58,18 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, loaded = backend.load(compiled) def compiled_callable(*inputs): + def refine_result_type(_result): + if isinstance(_result, tuple): + return tuple(refine_result_type(x) for x in _result) + elif isinstance(_result, numpy.ndarray): + return torch.from_numpy(_result) + elif isinstance(_result, (bool, int, float)): + return _result + else: + raise ValueError(f"Unhandled return type {type(_result)}") inputs = [x.numpy() for x in inputs] result = loaded.forward(*inputs) - if not isinstance(result, tuple): - result = torch.from_numpy(result) - else: - result = tuple(torch.from_numpy(x) for x in result) - return result + return refine_result_type(result) return compiled_callable diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1aa14a70e..376e3bf18 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -2705,6 +2705,27 @@ def IntFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== + +class AtenSubFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.value1 = 1.0 + self.value2 = 2.0 + + @export + @annotate_args([ + None, + ]) + def forward(self): + return float(torch.ops.aten.sub(self.value1, self.value2)) + +@register_test_case(module_factory=lambda: AtenSubFloatModule()) +def AtenSubFloatModule_basic(module, tu: TestUtils): + module.forward() + + # ============================================================================== class ScalarImplicitFloatModule(torch.nn.Module): diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 13eed7718..4b487e716 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1686,6 +1686,16 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], return %2 : !torch.vtensor<[],si64> } +// CHECK-LABEL: func.func @torch.aten.sub.float$fold() -> !torch.float { +// CHECK: %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00 +// CHECK: return %[[FLOAT_1]] : !torch.float +func.func @torch.aten.sub.float$fold() -> !torch.float { + %float1 = torch.constant.float 1.0 + %float2 = torch.constant.float 2.0 + %0 = torch.aten.sub.float %float1, %float2 : !torch.float, !torch.float -> !torch.float + return %0 : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %[[INT6]] = torch.constant.int 6 // CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>