[Torch Dialect] add folder for aten.sub.float (#1871)

Refs: pull/1911/head, snapshot-20230303.766
Author: Yuanqiang Liu, 2023-03-03 01:07:33 +08:00 (committed by GitHub)
Parent: 50ad5eab9a
Commit: 7a8304f935
7 changed files with 53 additions and 10 deletions


@@ -44,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # RuntimeError: Failed running call_function aten.uniform(...
     # https://github.com/pytorch/torchdynamo/issues/1954
     "UniformNoCorrelationModule_basic",
-    # TypeError: expected np.ndarray (got float)
-    # TODO: This is due to returning a scalar float as output from the test.
-    # We should probably just standardize all tests to return tensors.
-    "DivIntModule_basic",
 
     #### Torch-MLIR internal compiler errors


@@ -9816,6 +9816,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [


@@ -2168,6 +2168,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSubFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a - b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSubOp
 //===----------------------------------------------------------------------===//
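
The new fold only fires when both operands are compile-time constants. As a reading aid, here is a minimal sketch of what a binary-float fold helper along these lines could look like; it is not the actual atenBinaryFloatOperatorFoldHelper from TorchOps.cpp, and the name foldBinaryFloatOp, the attribute handling (the real helper may also accept integer attributes), and the result type are assumptions for illustration.

// Illustrative sketch only, not the torch-mlir helper: fold a binary float
// op when both operands are constant attributes; otherwise return a null
// OpFoldResult so the op is left untouched in the IR.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

static mlir::OpFoldResult
foldBinaryFloatOp(llvm::ArrayRef<mlir::Attribute> operands,
                  llvm::function_ref<double(double, double)> combine) {
  // fold() receives one Attribute per operand; a null Attribute means that
  // operand is not a known constant, so there is nothing to fold.
  auto lhs = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
  auto rhs = operands[1].dyn_cast_or_null<mlir::FloatAttr>();
  if (!lhs || !rhs)
    return nullptr;
  double value = combine(lhs.getValueAsDouble(), rhs.getValueAsDouble());
  return mlir::FloatAttr::get(mlir::Float64Type::get(lhs.getContext()), value);
}

With both operands constant, torch.aten.sub.float 1.0, 2.0 folds to the float attribute -1.0, which is exactly what the new canonicalize.mlir test at the end of this commit checks.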


@@ -607,7 +607,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)")
-    emit("aten::sub.float : (float, float) -> (float)")
+    emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
     emit("aten::mul.float : (float, float) -> (float)")
     emit("aten::div.float : (float, float) -> (float)", has_folder=True)
     emit("aten::neg.float : (float) -> (float)")


@@ -5,6 +5,7 @@
 from typing import List
 
+import numpy
 import torch
 import torch._dynamo as dynamo
 
 import torch_mlir
@@ -57,13 +58,18 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
     loaded = backend.load(compiled)
 
     def compiled_callable(*inputs):
+        def refine_result_type(_result):
+            if isinstance(_result, tuple):
+                return tuple(refine_result_type(x) for x in _result)
+            elif isinstance(_result, numpy.ndarray):
+                return torch.from_numpy(_result)
+            elif isinstance(_result, (bool, int, float)):
+                return _result
+            else:
+                raise ValueError(f"Unhandled return type {type(_result)}")
         inputs = [x.numpy() for x in inputs]
         result = loaded.forward(*inputs)
-        if not isinstance(result, tuple):
-            result = torch.from_numpy(result)
-        else:
-            result = tuple(torch.from_numpy(x) for x in result)
-        return result
+        return refine_result_type(result)
     return compiled_callable


@@ -2705,6 +2705,27 @@ def IntFloatModule_basic(module, tu: TestUtils):
     module.forward()
 
 # ==============================================================================
 
+class AtenSubFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.value1 = 1.0
+        self.value2 = 2.0
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return float(torch.ops.aten.sub(self.value1, self.value2))
+
+
+@register_test_case(module_factory=lambda: AtenSubFloatModule())
+def AtenSubFloatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
 class ScalarImplicitFloatModule(torch.nn.Module):


@@ -1686,6 +1686,16 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// CHECK-LABEL: func.func @torch.aten.sub.float$fold() -> !torch.float {
+// CHECK: %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00
+// CHECK: return %[[FLOAT_1]] : !torch.float
+func.func @torch.aten.sub.float$fold() -> !torch.float {
+  %float1 = torch.constant.float 1.0
+  %float2 = torch.constant.float 2.0
+  %0 = torch.aten.sub.float %float1, %float2 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
 // CHECK: %[[INT6:.*]] = torch.constant.int 6
 // CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>