[TorchToLinalg] Lower aten.cat to tensor.concat (#2650)

This replaces the lowering of aten.cat with tensor.concat, allowing more
efficient handling of concatenations in downstream flows. The refbackend
populates concat decomposition patterns that can be used to recover the
previous lowering.
pull/2658/head
Quinn Dawkins 2023-12-15 15:45:32 -05:00 committed by GitHub
parent 061af696ce
commit 030b0140d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 79 additions and 45 deletions

View File

@ -31,6 +31,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend
} // namespace torch

View File

@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let dependentDialects = ["memref::MemRefDialect"];
}
def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
let summary = "Convert tensor.concat to other tensor ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
}
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";

View File

@ -1033,9 +1033,12 @@ public:
auto outElemType = newResultType.getElementType();
for (size_t i = 0; i < tensors.size(); ++i) {
auto inputType = cast<RankedTensorType>(tensors[i].getType());
if (inputType.getElementType() != outElemType) {
tensors[i] = torch_to_linalg::convertTensorToElementType(
rewriter, loc, tensors[i], outElemType);
}
}
int rank = newResultType.getRank();
Value dimValue = op.getDim();
@ -1046,48 +1049,8 @@ public:
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank);
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
for (int i = 0; i < rank; ++i)
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
// Calculate the size of the `dim` result dimension by adding the dim size
// of each tensor together.
Value resultDimSize = sizes[dim];
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim));
for (auto tensor : ArrayRef(tensors).drop_front()) {
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
}
sizes[dim] = resultDimSize;
auto toOpFoldResult = [](Value v) -> OpFoldResult {
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
if (!op)
return v;
return op.getValue();
};
Value result = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(sizes), newResultType.getElementType());
for (auto tensor : tensors) {
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
result = rewriter.createOrFold<tensor::InsertSliceOp>(
loc, tensor, result,
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
offsets[dim] =
rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
tensors);
return success();
}
};

View File

@ -20,10 +20,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() {
return std::make_unique<MungeMemrefCopy>();
}
namespace {
class GeneralizeTensorConcat
: public GeneralizeTensorConcatBase<GeneralizeTensorConcat> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
tensor::populateDecomposeTensorConcatPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createGeneralizeTensorConcatPass() {
return std::make_unique<GeneralizeTensorConcat>();
}
namespace {
class GeneralizeTensorPad
: public GeneralizeTensorPadBase<GeneralizeTensorPad> {

View File

@ -123,6 +123,7 @@ class RefBackendInvoker:
LOWERING_PIPELINE = "builtin.module(" + ",".join([
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Apply some optimizations. It would be great if MLIR had more useful
# optimizations that worked out of the box here.
# Note: When measured, this doesn't seem to actually help that much

View File

@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
return %0 : !torch.vtensor<[?,?],f16>
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat$convert(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor<?x?xi32>) outs(%{{.*}}: tensor<?x?xf32>)
// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}