mirror of https://github.com/llvm/torch-mlir
parent
50ad5eab9a
commit
7a8304f935
|
@ -44,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# RuntimeError: Failed running call_function aten.uniform(...
|
# RuntimeError: Failed running call_function aten.uniform(...
|
||||||
# https://github.com/pytorch/torchdynamo/issues/1954
|
# https://github.com/pytorch/torchdynamo/issues/1954
|
||||||
"UniformNoCorrelationModule_basic",
|
"UniformNoCorrelationModule_basic",
|
||||||
# TypeError: expected np.ndarray (got float)
|
|
||||||
# TODO: This is due to returning a scalar float as output from the test.
|
|
||||||
# We should probably just standardize all tests to return tensors.
|
|
||||||
"DivIntModule_basic",
|
|
||||||
|
|
||||||
#### Torch-MLIR internal compiler errors
|
#### Torch-MLIR internal compiler errors
|
||||||
|
|
||||||
|
|
|
@ -9816,6 +9816,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
||||||
|
|
|
@ -2168,6 +2168,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenSubFloatOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
|
||||||
|
return atenBinaryFloatOperatorFoldHelper(
|
||||||
|
adaptor.getOperands(), [](double a, double b) { return a - b; });
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenSubOp
|
// AtenSubOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -607,7 +607,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||||
emit("aten::log.int : (int) -> (float)")
|
emit("aten::log.int : (int) -> (float)")
|
||||||
emit("aten::add.float_int : (float, int) -> (float)")
|
emit("aten::add.float_int : (float, int) -> (float)")
|
||||||
emit("aten::sub.float : (float, float) -> (float)")
|
emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
|
||||||
emit("aten::mul.float : (float, float) -> (float)")
|
emit("aten::mul.float : (float, float) -> (float)")
|
||||||
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
|
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
|
||||||
emit("aten::neg.float : (float) -> (float)")
|
emit("aten::neg.float : (float) -> (float)")
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo as dynamo
|
import torch._dynamo as dynamo
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
@ -57,13 +58,18 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||||
loaded = backend.load(compiled)
|
loaded = backend.load(compiled)
|
||||||
|
|
||||||
def compiled_callable(*inputs):
|
def compiled_callable(*inputs):
|
||||||
|
def refine_result_type(_result):
|
||||||
|
if isinstance(_result, tuple):
|
||||||
|
return tuple(refine_result_type(x) for x in _result)
|
||||||
|
elif isinstance(_result, numpy.ndarray):
|
||||||
|
return torch.from_numpy(_result)
|
||||||
|
elif isinstance(_result, (bool, int, float)):
|
||||||
|
return _result
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unhandled return type {type(_result)}")
|
||||||
inputs = [x.numpy() for x in inputs]
|
inputs = [x.numpy() for x in inputs]
|
||||||
result = loaded.forward(*inputs)
|
result = loaded.forward(*inputs)
|
||||||
if not isinstance(result, tuple):
|
return refine_result_type(result)
|
||||||
result = torch.from_numpy(result)
|
|
||||||
else:
|
|
||||||
result = tuple(torch.from_numpy(x) for x in result)
|
|
||||||
return result
|
|
||||||
return compiled_callable
|
return compiled_callable
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2705,6 +2705,27 @@ def IntFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenSubFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.value1 = 1.0
|
||||||
|
self.value2 = 2.0
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return float(torch.ops.aten.sub(self.value1, self.value2))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenSubFloatModule())
|
||||||
|
def AtenSubFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ScalarImplicitFloatModule(torch.nn.Module):
|
class ScalarImplicitFloatModule(torch.nn.Module):
|
||||||
|
|
|
@ -1686,6 +1686,16 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
|
||||||
return %2 : !torch.vtensor<[],si64>
|
return %2 : !torch.vtensor<[],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sub.float$fold() -> !torch.float {
|
||||||
|
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00
|
||||||
|
// CHECK: return %[[FLOAT_1]] : !torch.float
|
||||||
|
func.func @torch.aten.sub.float$fold() -> !torch.float {
|
||||||
|
%float1 = torch.constant.float 1.0
|
||||||
|
%float2 = torch.constant.float 2.0
|
||||||
|
%0 = torch.aten.sub.float %float1, %float2 : !torch.float, !torch.float -> !torch.float
|
||||||
|
return %0 : !torch.float
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||||
// CHECK: %[[INT6]] = torch.constant.int 6
|
// CHECK: %[[INT6]] = torch.constant.int 6
|
||||||
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
|
Loading…
Reference in New Issue