mirror of https://github.com/llvm/torch-mlir
[Torch] Fix mixP case for non value semantic ops (#2540)
Non-value-semantic ops such as `add_`, `div_`, etc. expect the result dtype to match that of the first input. However, the current implementation produces the wrong result type for a case like: ```python a = torch.randn(3, 3).half() # float16 b = torch.randn(3, 3) # float32 a += b # i.e. torch.ops.aten.add_(a, b) ``` Torch expects `a` to remain float16, but dtype refinement would infer float32, since the op is replaced by `aten.add`. pull/2545/head snapshot-20231102.1010
parent
4901773f77
commit
88d4c475d3
|
@ -243,8 +243,20 @@ public:
|
|||
"Torch JIT operators shouldn't have regions or successors");
|
||||
|
||||
Operation *newOp = rewriter.create(state);
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
||||
// Note: need to convert result to first input's dtype because mix precision
|
||||
// compute would result in different behaviors.
|
||||
// For example:
|
||||
// a = torch.randn(3, 3).half() # float16
|
||||
// b = torch.randn(3, 3) # float32
|
||||
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
|
||||
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
|
||||
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
|
||||
Value cstFalse = rewriter.create<ConstantBoolOp>(op->getLoc(), false);
|
||||
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
|
||||
auto toDtype = rewriter.create<AtenToDtypeOp>(
|
||||
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
|
||||
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
|
||||
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
|
||||
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
|
||||
op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
|
|
|
@ -1012,6 +1012,30 @@ def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class Add_MixPModule(torch.nn.Module):
    """Regression test for mixed-precision in-place add (aten.add_).

    With a float32 left operand and a float64 right operand, ``a += b``
    must keep the left operand's dtype (float32), unlike the
    out-of-place ``a + b``, which promotes to float64.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, a, b):
        # In-place add: result dtype must match `a`, not the promoted type.
        a += b
        return a
|
@register_test_case(module_factory=lambda: Add_MixPModule())
def Add_MixPModule_basic(module, tu: TestUtils):
    # Mixed-precision operands: float32 lhs, float64 rhs.
    lhs = tu.rand(3, 3)
    rhs = tu.rand(3, 3).double()
    module.forward(lhs, rhs)
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class EmbeddingModuleI64(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -94,7 +94,11 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
|
|||
// (which is cleaned up by canonicalization) is an artifact of two patterns
|
||||
// being applied in sequence.
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[DTYPE_RESULT:.*]] = torch.aten.to.dtype %[[ARRAY_RESULT]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[DTYPE_RESULT]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
|
|
Loading…
Reference in New Issue