mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Lower aten.cat to tensor.concat (#2650)
This replaces the lowering of aten.cat with tensor.concat, allowing more efficient handling of concatenations in downstream flows. The refbackend populates concat decomposition patterns that can be used to recover the previous lowering.pull/2658/head
parent
061af696ce
commit
030b0140d4
|
@ -31,6 +31,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
|
|||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
|
||||
} // namespace RefBackend
|
||||
} // namespace torch
|
||||
|
|
|
@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
|
|||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
|
||||
let summary = "Convert tensor.concat to other tensor ops";
|
||||
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
|
||||
}
|
||||
|
||||
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
|
||||
let summary = "Convert tensor.pad to linalg ops";
|
||||
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
|
||||
|
|
|
@ -1033,9 +1033,12 @@ public:
|
|||
|
||||
auto outElemType = newResultType.getElementType();
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto inputType = cast<RankedTensorType>(tensors[i].getType());
|
||||
if (inputType.getElementType() != outElemType) {
|
||||
tensors[i] = torch_to_linalg::convertTensorToElementType(
|
||||
rewriter, loc, tensors[i], outElemType);
|
||||
}
|
||||
}
|
||||
|
||||
int rank = newResultType.getRank();
|
||||
Value dimValue = op.getDim();
|
||||
|
@ -1046,48 +1049,8 @@ public:
|
|||
if (!isValidDim(dim, rank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<Value> offsets, sizes, strides;
|
||||
sizes.reserve(rank);
|
||||
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
|
||||
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
||||
|
||||
for (int i = 0; i < rank; ++i)
|
||||
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
|
||||
|
||||
// Calculate the size of the `dim` result dimension by adding the dim size
|
||||
// of each tensor together.
|
||||
Value resultDimSize = sizes[dim];
|
||||
|
||||
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, rewriter.getIndexAttr(dim));
|
||||
for (auto tensor : ArrayRef(tensors).drop_front()) {
|
||||
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
resultDimSize =
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
|
||||
}
|
||||
sizes[dim] = resultDimSize;
|
||||
|
||||
auto toOpFoldResult = [](Value v) -> OpFoldResult {
|
||||
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
|
||||
if (!op)
|
||||
return v;
|
||||
return op.getValue();
|
||||
};
|
||||
|
||||
Value result = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(sizes), newResultType.getElementType());
|
||||
for (auto tensor : tensors) {
|
||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
|
||||
result = rewriter.createOrFold<tensor::InsertSliceOp>(
|
||||
loc, tensor, result,
|
||||
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
|
||||
offsets[dim] =
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
|
||||
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
|
||||
tensors);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -20,10 +20,12 @@
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Approximation.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
|
@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() {
|
|||
return std::make_unique<MungeMemrefCopy>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class GeneralizeTensorConcat
|
||||
: public GeneralizeTensorConcatBase<GeneralizeTensorConcat> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
tensor::populateDecomposeTensorConcatPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::RefBackend::createGeneralizeTensorConcatPass() {
|
||||
return std::make_unique<GeneralizeTensorConcat>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class GeneralizeTensorPad
|
||||
: public GeneralizeTensorPadBase<GeneralizeTensorPad> {
|
||||
|
|
|
@ -123,6 +123,7 @@ class RefBackendInvoker:
|
|||
|
||||
LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
||||
"func.func(refback-generalize-tensor-pad)",
|
||||
"func.func(refback-generalize-tensor-concat)",
|
||||
# Apply some optimizations. It would be great if MLIR had more useful
|
||||
# optimizations that worked out of the box here.
|
||||
# Note: When measured, this doesn't seem to actually help that much
|
||||
|
|
|
@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso
|
|||
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
|
||||
return %0 : !torch.vtensor<[?,?],f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.cat$convert(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor<?x?xi32>) outs(%{{.*}}: tensor<?x?xf32>)
|
||||
// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
|
||||
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %1 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.cat(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
|
||||
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %1 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue