[Torch] Fix mixP case for non value semantic ops (#2540)

Non-value-semantic ops like `add_`, `div_`, etc. expect the result dtype to
match that of the first input. However, the current implementation infers the
wrong result type for a case like:

```python
a = torch.randn(3, 3).half() # float16
b = torch.randn(3, 3) # float32
a += b # i.e. torch.ops.aten.add_(a, b)
```
torch expects `a` to stay float16, but dtype refinement infers float32,
because the in-place `aten.add_` is replaced by the value-semantic `aten.add`.
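
For reference, the expected eager-mode behavior (which this fix preserves on the compiled path) can be checked in plain PyTorch:

```python
import torch

a = torch.randn(3, 3).half()   # float16
b = torch.randn(3, 3)          # float32

a += b      # in-place aten.add_: keeps the first operand's dtype
c = a + b   # value-semantic aten.add: promotes to float32

assert a.dtype == torch.float16
assert c.dtype == torch.float32
```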
Zhekun(Josh) Zhang 2023-11-02 12:40:08 +08:00 committed by GitHub
parent 4901773f77
commit 88d4c475d3
3 changed files with 43 additions and 3 deletions


```diff
@@ -243,8 +243,20 @@ public:
            "Torch JIT operators shouldn't have regions or successors");
     Operation *newOp = rewriter.create(state);
-    auto tensor =
-        rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+    // Note: need to convert result to first input's dtype because mix precision
+    // compute would result in different behaviors.
+    // For example:
+    // a = torch.randn(3, 3).half() # float16
+    // b = torch.randn(3, 3) # float32
+    // a += b # i.e. torch.ops.aten.add_(a, b), result is float16
+    // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
+    Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
+    Value cstFalse = rewriter.create<ConstantBoolOp>(op->getLoc(), false);
+    auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
+    auto toDtype = rewriter.create<AtenToDtypeOp>(
+        op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
+        aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
```
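
In eager-mode Python terms, the rewritten pattern now behaves roughly like the sketch below. The helper name is made up for illustration only; the actual pass emits `aten.to.dtype` followed by `torch.overwrite.tensor.contents` rather than calling these Python APIs:

```python
import torch

# Illustrative sketch (hypothetical helper): expressing a.add_(b) with
# value-semantic ops while preserving the first operand's dtype.
def add_inplace_via_value_semantics(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    result = torch.add(a, b)           # value-semantic add: promotes to the wider dtype
    result = result.to(dtype=a.dtype)  # corresponds to the new aten.to.dtype cast
    a.copy_(result)                    # corresponds to overwriting the tensor contents
    return a

a = torch.randn(3, 3).half()   # float16
b = torch.randn(3, 3)          # float32
assert add_inplace_via_value_semantics(a, b).dtype == torch.float16
```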


```diff
@@ -1012,6 +1012,30 @@ def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class Add_MixPModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        a += b
+        return a
+
+
+@register_test_case(module_factory=lambda: Add_MixPModule())
+def Add_MixPModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3), tu.rand(3, 3).double())
+
+
+# ==============================================================================
+
+
 class EmbeddingModuleI64(torch.nn.Module):
 
     def __init__(self):
```
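
Outside the e2e harness, the same float32/float64 mix can be sanity-checked in eager mode; this standalone snippet mirrors what `Add_MixPModule_basic` exercises, without the `@export`/`@annotate_args` machinery:

```python
import torch

a = torch.rand(3, 3)            # float32
b = torch.rand(3, 3).double()   # float64

a += b                          # in-place add: result must stay float32

assert a.dtype == torch.float32
```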


```diff
@@ -94,7 +94,11 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
 // (which is cleaned up by canonicalization) is an artifact of two patterns
 // being applied in sequence.
 // CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
-// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 6
+// CHECK: %[[DTYPE_RESULT:.*]] = torch.aten.to.dtype %[[ARRAY_RESULT]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[2,2],f32>
+// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[DTYPE_RESULT]] : !torch.vtensor<[2,2],f32>
 // CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
```