mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660)
The lowering pattern for `aten.T` uses transposition implemented via `linalg.generic`. For downstream passes it is advantageous to use named ops wherever possible, so this patch changes the lowering to use `linalg.transpose` instead.
parent
70d5730c87
commit
df6098e43d
|
@ -1795,32 +1795,16 @@ public:
|
|||
|
||||
Value outVector = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outputDims), elementType);
|
||||
SmallVector<AffineExpr> idExprs;
|
||||
SmallVector<AffineExpr> swapExprs;
|
||||
for (auto i = 0; i < inputRank; i++)
|
||||
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
|
||||
for (auto i = 0; i < inputRank; i++) {
|
||||
if (i == dim0)
|
||||
swapExprs.push_back(idExprs[dim1]);
|
||||
else if (i == dim1)
|
||||
swapExprs.push_back(idExprs[dim0]);
|
||||
else
|
||||
swapExprs.push_back(idExprs[i]);
|
||||
}
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
|
||||
AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
inputRank, utils::IteratorType::parallel);
|
||||
auto transpose = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outVector.getType(), inVector, outVector,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
SmallVector<int64_t> permutation(inputRank);
|
||||
std::iota(permutation.begin(), permutation.end(), 0);
|
||||
permutation[dim0] = dim1;
|
||||
permutation[dim1] = dim0;
|
||||
|
||||
auto transpose =
|
||||
rewriter
|
||||
.create<linalg::TransposeOp>(loc, inVector, outVector, permutation)
|
||||
.getResult();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -339,3 +339,18 @@ func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
|
|||
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %1 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.transpose$basic(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
// CHECK: %[[IN_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
|
||||
// CHECK: %[[TRANSP:.*]] = linalg.transpose ins(%[[IN_0]] : tensor<4x3xf32>) outs(%1 : tensor<3x4xf32>) permutation = [1, 0]
|
||||
// CHECK: %[[OUT_0:.*]] = torch_c.from_builtin_tensor %{{.*}} : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: return %[[OUT_0]] : !torch.vtensor<[3,4],f32>
|
||||
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue