[LINALG] Fix linalg generic result type argument in TorchToLinalg (#588)

Some of the lowerings use the result type obtained from the op itself
to tell the `linalg::GenericOp` what the type of the result should be
rather than using the type of the result tensor given to the
`linalg::GenericOp`. This becomes a problem when the result type of
the op has static size information and the result tensor used in
`linalg::GenericOp` has dynamic dimensions, for `linalg::GenericOp`
expects the result type to be equal to the type of the output tensor.

This commit replaces the use of the result type from the op itself
with the type of the result tensor passed to `linalg::GenericOp`.

In order to not create too many dynamic/static versions of the same
e2e test, e2e tests have only been added to the ops that currently
fail when used with static sizes.
pull/591/head
Ramiro Leal-Cavazos 2022-02-11 19:42:18 -08:00 committed by GitHub
parent 73ac9a7e2e
commit 3dc7847348
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 12 deletions

View File

@ -480,6 +480,26 @@ def GatherModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class GatherStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 4], torch.float32, True),
([1, 2, 3], torch.int64, True),
])
def forward(self, tensor, indices):
return torch.gather(tensor, 2, indices)
@register_test_case(module_factory=lambda: GatherStaticModule())
def GatherStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
# ==============================================================================
class AddSizeIntModule(torch.nn.Module): class AddSizeIntModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -6,6 +6,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ============================================================================== # ==============================================================================
class UniformModule(torch.nn.Module): class UniformModule(torch.nn.Module):
def __init__(self): def __init__(self):
@ -44,6 +45,44 @@ def UniformModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class UniformStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([256, 512, 8], torch.float64, True),
([512, 1024, 4], torch.float64, True),
([512, 256, 4], torch.float64, True),
])
def forward(self, x, y, z):
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
std = torch.cat([
torch.flatten(torch.std(a)),
torch.flatten(torch.std(b)),
torch.flatten(torch.std(c))
])
mean = torch.cat([
torch.flatten(torch.mean(a)),
torch.flatten(torch.mean(b)),
torch.flatten(torch.mean(c))
])
return std, mean
@register_test_case(module_factory=lambda: UniformStaticModule())
def UniformStaticModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(256, 512, 8).double(),
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double())
# ==============================================================================
class BernoulliModule(torch.nn.Module): class BernoulliModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1129,7 +1129,7 @@ public:
Value finalRes = Value finalRes =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, newResultType, ValueRange{lhs, rhs}, initTensor0, loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0,
/*indexingMaps=*/indexingMaps, /*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes, /*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
@ -1293,7 +1293,7 @@ public:
Value finalRes = Value finalRes =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, resultType, ValueRange{target}, initTensor0, loc, initTensor0.getType(), ValueRange{target}, initTensor0,
/*indexingMaps=*/indexingMaps, /*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes, /*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
@ -1395,7 +1395,8 @@ public:
Value finalRes = Value finalRes =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, resultType, ValueRange{target, gradOutput}, initTensor0, loc, initTensor0.getType(), ValueRange{target, gradOutput},
initTensor0,
/*indexingMaps=*/indexingMaps, /*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes, /*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
@ -3702,14 +3703,17 @@ public:
SmallVector<AffineMap, 2> affineMaps(2, SmallVector<AffineMap, 2> affineMaps(2,
rewriter.getMultiDimIdentityMap(rank)); rewriter.getMultiDimIdentityMap(rank));
SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName()); SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName());
auto genericOp = rewriter.create<linalg::GenericOp>( auto genericOp = rewriter
loc, newResultTy, indices, result, affineMaps, iteratorTypes, .create<linalg::GenericOp>(
loc, result.getType(), indices, result, affineMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
auto index = args[0]; auto index = args[0];
createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index, createLinalgPayloadCalculationForGatherOps(
dim, rank); b, loc, self, rank, index, dim, rank);
}); })
rewriter.replaceOp(op, genericOp.getResult(0)); .getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultTy, genericOp);
return success(); return success();
} }
}; };
@ -4481,7 +4485,7 @@ public:
Value uniformRes = Value uniformRes =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, resultType, /*inputs=*/ValueRange{}, loc, initTensor.getType(), /*inputs=*/ValueRange{},
/*outputs=*/initTensor, indexingMaps, iteratorTypes, /*outputs=*/initTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value temp = initialSeed; Value temp = initialSeed;