mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix linalg generic result type argument in TorchToLinalg (#588)
Some of the lowerings use the result type obtained from the op itself to tell the `linalg::GenericOp` what the type of the result should be rather than using the type of the result tensor given to the `linalg::GenericOp`. This becomes a problem when the result type of the op has static size information and the result tensor used in `linalg::GenericOp` has dynamic dimensions, for `linalg::GenericOp` expects the result type to be equal to the type of the output tensor. This commit replaces the use of the result type from the op itself with the type of the result tensor passed to `linalg::GenericOp`. In order to not create too many dynamic/static versions of the same e2e test, e2e tests have only been added to the ops that currently fail when used with static sizes.pull/591/head
parent
73ac9a7e2e
commit
3dc7847348
|
@ -480,6 +480,26 @@ def GatherModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class GatherStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
([1, 2, 3], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 2, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GatherStaticModule())
|
||||
def GatherStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AddSizeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -6,6 +6,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -44,6 +45,44 @@ def UniformModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([256, 512, 8], torch.float64, True),
|
||||
([512, 1024, 4], torch.float64, True),
|
||||
([512, 256, 4], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
||||
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
||||
std = torch.cat([
|
||||
torch.flatten(torch.std(a)),
|
||||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
mean = torch.cat([
|
||||
torch.flatten(torch.mean(a)),
|
||||
torch.flatten(torch.mean(b)),
|
||||
torch.flatten(torch.mean(c))
|
||||
])
|
||||
return std, mean
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UniformStaticModule())
|
||||
def UniformStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(256, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BernoulliModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1129,7 +1129,7 @@ public:
|
|||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, newResultType, ValueRange{lhs, rhs}, initTensor0,
|
||||
loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
|
@ -1293,7 +1293,7 @@ public:
|
|||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, resultType, ValueRange{target}, initTensor0,
|
||||
loc, initTensor0.getType(), ValueRange{target}, initTensor0,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
|
@ -1395,7 +1395,8 @@ public:
|
|||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, resultType, ValueRange{target, gradOutput}, initTensor0,
|
||||
loc, initTensor0.getType(), ValueRange{target, gradOutput},
|
||||
initTensor0,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
|
@ -3702,14 +3703,17 @@ public:
|
|||
SmallVector<AffineMap, 2> affineMaps(2,
|
||||
rewriter.getMultiDimIdentityMap(rank));
|
||||
SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName());
|
||||
auto genericOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, newResultTy, indices, result, affineMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
auto index = args[0];
|
||||
createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index,
|
||||
dim, rank);
|
||||
});
|
||||
rewriter.replaceOp(op, genericOp.getResult(0));
|
||||
auto genericOp = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, result.getType(), indices, result, affineMaps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
auto index = args[0];
|
||||
createLinalgPayloadCalculationForGatherOps(
|
||||
b, loc, self, rank, index, dim, rank);
|
||||
})
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultTy, genericOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -4481,7 +4485,7 @@ public:
|
|||
Value uniformRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, resultType, /*inputs=*/ValueRange{},
|
||||
loc, initTensor.getType(), /*inputs=*/ValueRange{},
|
||||
/*outputs=*/initTensor, indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value temp = initialSeed;
|
||||
|
|
Loading…
Reference in New Issue