mirror of https://github.com/llvm/torch-mlir
commit 7a8304f935 (parent 50ad5eab9a)

This change adds a constant folder for torch.aten.sub.float, regenerates the op's ODS, teaches the TorchDynamo RefBackend config to handle scalar (non-tensor) return values, and adds e2e and canonicalization coverage; the DivIntModule_basic xfail is removed accordingly.
@@ -44,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # RuntimeError: Failed running call_function aten.uniform(...
     # https://github.com/pytorch/torchdynamo/issues/1954
     "UniformNoCorrelationModule_basic",
-    # TypeError: expected np.ndarray (got float)
-    # TODO: This is due to returning a scalar float as output from the test.
-    # We should probably just standardize all tests to return tensors.
-    "DivIntModule_basic",

     #### Torch-MLIR internal compiler errors

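The TypeError noted in the removed xfail entry comes from the old backend wrapper passing every result to torch.from_numpy, which only accepts ndarrays. A minimal standalone reproduction of that failure mode (not taken from the repository, just illustrating the error message):

import torch

# A compiled function that returns a plain Python float rather than an
# ndarray, as DivIntModule-style scalar tests do.
scalar_result = 0.5

try:
    torch.from_numpy(scalar_result)  # the old unconditional conversion
except TypeError as e:
    # Prints "expected np.ndarray (got float)", matching the error above.
    print(e)
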
@@ -9816,6 +9816,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
@@ -2168,6 +2168,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSubFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a - b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSubOp
 //===----------------------------------------------------------------------===//
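As a rough Python model (not the actual C++ helper) of what this folder does: atenBinaryFloatOperatorFoldHelper only folds when both operands are compile-time float constants, and otherwise reports that no fold applies, which the sketch below represents with None:

from typing import Callable, Optional

def binary_float_fold(lhs: Optional[float], rhs: Optional[float],
                      f: Callable[[float, float], float]) -> Optional[float]:
    # Fold only when both operands are known constant floats;
    # otherwise signal "no fold" so the op is left in place.
    if lhs is None or rhs is None:
        return None
    return f(lhs, rhs)

# Mirrors the lambda used in AtenSubFloatOp::fold above.
assert binary_float_fold(1.0, 2.0, lambda a, b: a - b) == -1.0
# A non-constant operand (modeled as None) prevents folding.
assert binary_float_fold(None, 2.0, lambda a, b: a - b) is None
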
@@ -607,7 +607,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)")
-    emit("aten::sub.float : (float, float) -> (float)")
+    emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
     emit("aten::mul.float : (float, float) -> (float)")
     emit("aten::div.float : (float, float) -> (float)", has_folder=True)
     emit("aten::neg.float : (float) -> (float)")
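The has_folder=True flag is what makes the ODS generator emit the "let hasFolder = 1;" line seen in the GeneratedTorchOps.td hunk above. A deliberately simplified, hypothetical sketch of that mapping (emit_td_for_op is not the real emitter API, just an illustration):

def emit_td_for_op(op_name: str, has_folder: bool = False) -> str:
    # Hypothetical stand-in for the real emitter: the only point being
    # illustrated is that has_folder=True adds "let hasFolder = 1;",
    # which in turn requires the C++ fold() defined in TorchOps.cpp.
    lines = [f'def Torch_{op_name} : Torch_Op<...> {{']
    if has_folder:
        lines.append("  let hasFolder = 1;")
    lines.append("}")
    return "\n".join(lines)

print(emit_td_for_op("AtenSubFloatOp", has_folder=True))
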
@@ -5,6 +5,7 @@
 
 from typing import List
 
+import numpy
 import torch
 import torch._dynamo as dynamo
 import torch_mlir
@@ -57,13 +58,18 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
     loaded = backend.load(compiled)
 
     def compiled_callable(*inputs):
+        def refine_result_type(_result):
+            if isinstance(_result, tuple):
+                return tuple(refine_result_type(x) for x in _result)
+            elif isinstance(_result, numpy.ndarray):
+                return torch.from_numpy(_result)
+            elif isinstance(_result, (bool, int, float)):
+                return _result
+            else:
+                raise ValueError(f"Unhandled return type {type(_result)}")
         inputs = [x.numpy() for x in inputs]
         result = loaded.forward(*inputs)
-        if not isinstance(result, tuple):
-            result = torch.from_numpy(result)
-        else:
-            result = tuple(torch.from_numpy(x) for x in result)
-        return result
+        return refine_result_type(result)
     return compiled_callable


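A quick standalone check of what the new refine_result_type helper produces for the result shapes the RefBackend can return; the function is copied verbatim here so the snippet runs on its own:

import numpy
import torch

def refine_result_type(_result):
    # Same logic as the helper added above: ndarrays become tensors,
    # scalars pass through, tuples are refined element by element.
    if isinstance(_result, tuple):
        return tuple(refine_result_type(x) for x in _result)
    elif isinstance(_result, numpy.ndarray):
        return torch.from_numpy(_result)
    elif isinstance(_result, (bool, int, float)):
        return _result
    else:
        raise ValueError(f"Unhandled return type {type(_result)}")

print(refine_result_type(numpy.ones((2, 2))))        # tensor of ones
print(refine_result_type(0.5))                        # 0.5, unchanged
print(refine_result_type((numpy.zeros(3), 7, True)))  # (tensor, 7, True)
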
@@ -2705,6 +2705,27 @@ def IntFloatModule_basic(module, tu: TestUtils):
     module.forward()
 
 
+# ==============================================================================
+
+class AtenSubFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.value1 = 1.0
+        self.value2 = 2.0
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return float(torch.ops.aten.sub(self.value1, self.value2))
+
+@register_test_case(module_factory=lambda: AtenSubFloatModule())
+def AtenSubFloatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
 # ==============================================================================
 
 class ScalarImplicitFloatModule(torch.nn.Module):
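For reference, the golden value this test is compared against is just Python float subtraction, i.e. -1.0, the same constant the new folder materializes; the snippet below checks that expectation in eager mode and exercises the scalar (non-tensor) return path enabled by the refine_result_type change above:

import torch

# What the test's forward() computes eagerly: 1.0 - 2.0 as a plain float.
expected = float(torch.ops.aten.sub(1.0, 2.0))
assert expected == -1.0
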
@@ -1686,6 +1686,16 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// CHECK-LABEL:   func.func @torch.aten.sub.float$fold() -> !torch.float {
+// CHECK:           %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00
+// CHECK:           return %[[FLOAT_1]] : !torch.float
+func.func @torch.aten.sub.float$fold() -> !torch.float {
+  %float1 = torch.constant.float 1.0
+  %float2 = torch.constant.float 2.0
+  %0 = torch.aten.sub.float %float1, %float2 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
 // CHECK:           %[[INT6:.*]] = torch.constant.int 6
 // CHECK:           %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>