[TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660)

The lowering pattern for `aten.T` uses transposition implemented via
`linalg.generic`. For downstream passes it is advantageous to use named
ops wherever possible, so this patch changes the lowering to use
`linalg.transpose` instead.
pull/3692/head
Felix Schneider 2024-09-07 08:09:10 +02:00 committed by GitHub
parent 70d5730c87
commit df6098e43d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 25 deletions

View File

@ -1795,32 +1795,16 @@ public:
Value outVector = rewriter.create<tensor::EmptyOp>( Value outVector = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outputDims), elementType); loc, getAsOpFoldResult(outputDims), elementType);
SmallVector<AffineExpr> idExprs;
SmallVector<AffineExpr> swapExprs;
for (auto i = 0; i < inputRank; i++)
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
for (auto i = 0; i < inputRank; i++) {
if (i == dim0)
swapExprs.push_back(idExprs[dim1]);
else if (i == dim1)
swapExprs.push_back(idExprs[dim0]);
else
swapExprs.push_back(idExprs[i]);
}
SmallVector<AffineMap> indexingMaps = { SmallVector<int64_t> permutation(inputRank);
AffineMap::get(inputRank, 0, idExprs, op.getContext()), std::iota(permutation.begin(), permutation.end(), 0);
AffineMap::get(inputRank, 0, swapExprs, op.getContext())}; permutation[dim0] = dim1;
SmallVector<utils::IteratorType> iteratorTypes( permutation[dim1] = dim0;
inputRank, utils::IteratorType::parallel);
auto transpose = rewriter auto transpose =
.create<linalg::GenericOp>( rewriter
loc, outVector.getType(), inVector, outVector, .create<linalg::TransposeOp>(loc, inVector, outVector, permutation)
indexingMaps, iteratorTypes, .getResult();
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
return success(); return success();
} }

View File

@ -339,3 +339,18 @@ func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32> %1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.transpose$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[IN_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
// CHECK: %[[TRANSP:.*]] = linalg.transpose ins(%[[IN_0]] : tensor<4x3xf32>) outs(%1 : tensor<3x4xf32>) permutation = [1, 0]
// CHECK: %[[OUT_0:.*]] = torch_c.from_builtin_tensor %{{.*}} : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[OUT_0]] : !torch.vtensor<[3,4],f32>
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}