From 3dc78473482a05286ffdb8a8a20d3ab43068012b Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Fri, 11 Feb 2022 19:42:18 -0800 Subject: [PATCH] [LINALG] Fix linalg generic result type argument in TorchToLinalg (#588) Some of the lowerings use the result type obtained from the op itself to tell the `linalg::GenericOp` what the type of the result should be rather than using the type of the result tensor given to the `linalg::GenericOp`. This becomes a problem when the result type of the op has static size information and the result tensor used in `linalg::GenericOp` has dynamic dimensions, for `linalg::GenericOp` expects the result type to be equal to the type of the output tensor. This commit replaces the use of the result type from the op itself with the type of the result tensor passed to `linalg::GenericOp`. In order to not create too many dynamic/static versions of the same e2e test, e2e tests have only been added to the ops that currently fail when used with static sizes. --- e2e_testing/torchscript/basic.py | 20 ++++++++++ e2e_testing/torchscript/rng.py | 39 +++++++++++++++++++ .../TorchToLinalg/TorchToLinalg.cpp | 28 +++++++------ 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index d41246634..d7d140c69 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -480,6 +480,26 @@ def GatherModule_basic(module, tu: TestUtils): # ============================================================================== +class GatherStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4], torch.float32, True), + ([1, 2, 3], torch.int64, True), + ]) + def forward(self, tensor, indices): + return torch.gather(tensor, 2, indices) + + +@register_test_case(module_factory=lambda: GatherStaticModule()) +def GatherStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]])) + +# ============================================================================== + class AddSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/e2e_testing/torchscript/rng.py b/e2e_testing/torchscript/rng.py index e9d21fb8c..68b280358 100644 --- a/e2e_testing/torchscript/rng.py +++ b/e2e_testing/torchscript/rng.py @@ -6,6 +6,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== + class UniformModule(torch.nn.Module): def __init__(self): @@ -44,6 +45,44 @@ def UniformModule_basic(module, tu: TestUtils): # ============================================================================== +class UniformStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([256, 512, 8], torch.float64, True), + ([512, 1024, 4], torch.float64, True), + ([512, 256, 4], torch.float64, True), + ]) + def forward(self, x, y, z): + a = torch.ops.aten.uniform_(x, 1.0, 10.0) + b = torch.ops.aten.uniform_(y, -20.0, -5.0) + c = torch.ops.aten.uniform_(z, -15.0, 3.0) + std = torch.cat([ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)) + ]) + mean = torch.cat([ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)) + ]) + return std, mean + + +@register_test_case(module_factory=lambda: UniformStaticModule()) +def UniformStaticModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(256, 512, 8).double(), + tu.rand(512, 1024, 4).double(), + tu.rand(512, 256, 4).double()) + +# ============================================================================== + class BernoulliModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 9f6c911a9..20ebc2db8 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1129,7 +1129,7 @@ public: Value finalRes = rewriter .create( - loc, newResultType, ValueRange{lhs, rhs}, initTensor0, + loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -1293,7 +1293,7 @@ public: Value finalRes = rewriter .create( - loc, resultType, ValueRange{target}, initTensor0, + loc, initTensor0.getType(), ValueRange{target}, initTensor0, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -1395,7 +1395,8 @@ public: Value finalRes = rewriter .create( - loc, resultType, ValueRange{target, gradOutput}, initTensor0, + loc, initTensor0.getType(), ValueRange{target, gradOutput}, + initTensor0, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -3702,14 +3703,17 @@ public: SmallVector affineMaps(2, rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes(rank, getParallelIteratorTypeName()); - auto genericOp = rewriter.create( - loc, newResultTy, indices, result, affineMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - auto index = args[0]; - createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index, - dim, rank); - }); - rewriter.replaceOp(op, genericOp.getResult(0)); + auto genericOp = rewriter + .create( + loc, result.getType(), indices, result, affineMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + auto index = args[0]; + createLinalgPayloadCalculationForGatherOps( + b, loc, self, rank, index, dim, rank); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultTy, genericOp); return success(); } }; @@ -4481,7 +4485,7 @@ public: Value uniformRes = rewriter .create( - loc, resultType, /*inputs=*/ValueRange{}, + loc, initTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/initTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value temp = initialSeed;