mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660)
The lowering pattern for `aten.T` uses transposition implemented via `linalg.generic`. For downstream passes it is advantageous to use named ops wherever possible, so this patch changes the lowering to use `linalg.transpose` instead.pull/3692/head
parent
70d5730c87
commit
df6098e43d
|
@ -1795,32 +1795,16 @@ public:
|
||||||
|
|
||||||
Value outVector = rewriter.create<tensor::EmptyOp>(
|
Value outVector = rewriter.create<tensor::EmptyOp>(
|
||||||
loc, getAsOpFoldResult(outputDims), elementType);
|
loc, getAsOpFoldResult(outputDims), elementType);
|
||||||
SmallVector<AffineExpr> idExprs;
|
|
||||||
SmallVector<AffineExpr> swapExprs;
|
|
||||||
for (auto i = 0; i < inputRank; i++)
|
|
||||||
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
|
|
||||||
for (auto i = 0; i < inputRank; i++) {
|
|
||||||
if (i == dim0)
|
|
||||||
swapExprs.push_back(idExprs[dim1]);
|
|
||||||
else if (i == dim1)
|
|
||||||
swapExprs.push_back(idExprs[dim0]);
|
|
||||||
else
|
|
||||||
swapExprs.push_back(idExprs[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<AffineMap> indexingMaps = {
|
SmallVector<int64_t> permutation(inputRank);
|
||||||
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
|
std::iota(permutation.begin(), permutation.end(), 0);
|
||||||
AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
|
permutation[dim0] = dim1;
|
||||||
SmallVector<utils::IteratorType> iteratorTypes(
|
permutation[dim1] = dim0;
|
||||||
inputRank, utils::IteratorType::parallel);
|
|
||||||
auto transpose = rewriter
|
auto transpose =
|
||||||
.create<linalg::GenericOp>(
|
rewriter
|
||||||
loc, outVector.getType(), inVector, outVector,
|
.create<linalg::TransposeOp>(loc, inVector, outVector, permutation)
|
||||||
indexingMaps, iteratorTypes,
|
.getResult();
|
||||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
|
||||||
b.create<linalg::YieldOp>(loc, args[0]);
|
|
||||||
})
|
|
||||||
.getResult(0);
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -339,3 +339,18 @@ func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
|
||||||
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
|
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
return %1 : !torch.vtensor<[?,?],f32>
|
return %1 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.transpose$basic(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
|
// CHECK: %[[IN_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
|
||||||
|
// CHECK: %[[TRANSP:.*]] = linalg.transpose ins(%[[IN_0]] : tensor<4x3xf32>) outs(%1 : tensor<3x4xf32>) permutation = [1, 0]
|
||||||
|
// CHECK: %[[OUT_0:.*]] = torch_c.from_builtin_tensor %{{.*}} : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||||
|
// CHECK: return %[[OUT_0]] : !torch.vtensor<[3,4],f32>
|
||||||
|
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue