mirror of https://github.com/llvm/torch-mlir
Add lowerings for tm_tensor.npbroadcast
This moves the existing broadcasting pattern into an opt-in TMTensor pass that lowers tm_tensor.npbroadcast ops to a pessimistic linalg equivalent. Doing so maintains the semantics of a numpy broadcast past the torch lowering stage. Additionally, it relocates the lowering for ops that used numpy broadcasting semantics (`aten.matmul`, `aten.broadcast_to`, and `aten.copy`) to the conversion to TMTensor.numpy_style_broadcast
parent
65bc15b340
commit
a7f506adc4
|
@ -337,16 +337,15 @@ def TMTensor_NumpyBroadcastOp : TMTensor_Op<"npbroadcast",
|
|||
let description = [{
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
|
||||
Variadic<AnyRankedTensorOrMemRefType>:$outputs
|
||||
Variadic<AnyShaped>:$inputs,
|
||||
Variadic<AnyShaped>:$outputs
|
||||
);
|
||||
let results = (outs Variadic<AnyRankedTensorOrMemRefType>:$results);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let results = (outs Variadic<AnyShaped>:$results);
|
||||
let assemblyFormat = [{
|
||||
attr-dict
|
||||
`ins` `(` $inputs `:` type($inputs) `)`
|
||||
`outs` `(` $outputs `:` type($outputs) `)`
|
||||
$region (`->` type($results)^)?
|
||||
(`->` type($results)^)?
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||
|
|
|
@ -17,6 +17,8 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace TMTensor {
|
||||
|
||||
// Factory functions for the TMTensor transformation passes. Each returns a
// newly allocated pass that is anchored on `func::FuncOp`.

// Pass that converts TMTensor NumpyBroadcastOps to linalg.
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createTMTensorBroadcastToLinalgPass();
|
||||
// Pass that converts TMTensor ops to loops and Linalg ops.
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
|
||||
// NOTE(review): presumably bufferizes TMTensor ops; the implementation is not
// visible here -- confirm against Bufferize.cpp.
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
|
||||
|
||||
|
|
|
@ -12,6 +12,12 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
// Function-level pass, registered under the command-line flag
// `tm-tensor-broadcast-to-linalg`, that rewrites TMTensor NumpyBroadcastOps
// into linalg. The pass object is created by the C++ factory named in
// `constructor`.
def TMTensorBroadcastToLinalg :
|
||||
Pass<"tm-tensor-broadcast-to-linalg", "func::FuncOp"> {
|
||||
let summary = "Convert TMTensor NumpyBroadcastOps to linalg.";
|
||||
let constructor = "mlir::torch::TMTensor::createTMTensorBroadcastToLinalgPass()";
|
||||
}
|
||||
|
||||
def TMTensorToLoops :
|
||||
Pass<"tm-tensor-to-loops", "func::FuncOp"> {
|
||||
let summary = "Convert TMTensor ops to loops and Linalg ops.";
|
||||
|
|
|
@ -884,11 +884,6 @@ LogicalResult NumpyBroadcastOp::verify() {
|
|||
if (getNumOutputs() != 1) {
|
||||
return emitOpError("expected one output operand");
|
||||
}
|
||||
|
||||
if (getInputRank() != getOutputRank()) {
|
||||
return emitOpError("expected input and output ranks to be the same");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_library(TorchMLIRTMTensorPasses
|
||||
ConvertBroadcastToLinalg.cpp
|
||||
ConvertToLoops.cpp
|
||||
Bufferize.cpp
|
||||
Passes.cpp
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch::TMTensor;
|
||||
|
||||
/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
|
||||
namespace {
|
||||
class LowerNumpyBroadcastToLinalg : public OpRewritePattern<NumpyBroadcastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(NumpyBroadcastOp broadcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = broadcastOp.getLoc();
|
||||
Value input = broadcastOp.getInput();
|
||||
Value output = broadcastOp.getOutput();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto outputType = output.getType().cast<RankedTensorType>();
|
||||
int64_t diff = outputType.getRank() - inputType.getRank();
|
||||
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
Value zeroIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
Value oneIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
rewriter.getMultiDimIdentityMap(outputType.getRank())};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
outputType.getRank(), utils::IteratorType::parallel);
|
||||
rewriter
|
||||
.replaceOpWithNewOp<linalg::GenericOp>(
|
||||
broadcastOp, output.getType(), ValueRange(), output, indexingMaps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
// `loopIndices` contains IV of the linalg loops which
|
||||
// would be used to extract values from the input tensor
|
||||
// later on.
|
||||
SmallVector<Value> loopIndices;
|
||||
for (int64_t i = diff; i < outputType.getRank(); ++i) {
|
||||
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||
}
|
||||
// `inputIndicesToExtract` contains i-th linalg loop IV if
|
||||
// the i-th input dimension is not 1, else it contains a
|
||||
// zero index.
|
||||
SmallVector<Value> inputIndicesToExtract;
|
||||
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
|
||||
if (inputShape[i] == 1 && outputShape[i + diff] != 1) {
|
||||
inputIndicesToExtract.push_back(zeroIndex);
|
||||
} else {
|
||||
Value inputDim = b.createOrFold<tensor::DimOp>(loc, input, i);
|
||||
Value isEqual = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
|
||||
Value select = b.create<arith::SelectOp>(
|
||||
loc, isEqual, zeroIndex, loopIndices[i]);
|
||||
inputIndicesToExtract.push_back(select);
|
||||
}
|
||||
}
|
||||
// Extract and yield the value from input tensor at
|
||||
// `inputIndicesToExtract` indices.
|
||||
Value result = b.create<tensor::ExtractOp>(loc, input,
|
||||
inputIndicesToExtract);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct TMTensorBroadcastToLinalgPass
|
||||
: public TMTensorBroadcastToLinalgBase<TMTensorBroadcastToLinalgPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect, mlir::arith::ArithDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<LowerNumpyBroadcastToLinalg>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
torch::TMTensor::createTMTensorBroadcastToLinalgPass() {
|
||||
return std::make_unique<TMTensorBroadcastToLinalgPass>();
|
||||
}
|
|
@ -26,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
|||
MLIRLinalgDialect
|
||||
MLIRMathDialect
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTMTensorDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToLinalg)
|
||||
|
|
|
@ -1162,57 +1162,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
SmallVector<Value> inShape;
|
||||
if (!getListConstructElements(adaptor.getSize(), inShape)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the size list is not from list construct");
|
||||
}
|
||||
// For dynamic input dimension we need to use the `broadcastToShape`
|
||||
// which in this case is `inShapeConverted` because this shape will yield
|
||||
// us the dimension size of the output.
|
||||
SmallVector<bool> useBroadcastToShape;
|
||||
for (auto x : inShape) {
|
||||
int64_t dim;
|
||||
if (!matchPattern(x, m_TorchConstantInt(&dim))) {
|
||||
Operation* defOp = x.getDefiningOp();
|
||||
if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
|
||||
useBroadcastToShape.push_back(true);
|
||||
else
|
||||
useBroadcastToShape.push_back(false);
|
||||
} else {
|
||||
useBroadcastToShape.push_back(false);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
|
||||
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
||||
Value result;
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
|
||||
inShapeConverted, result,
|
||||
useBroadcastToShape))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
|
||||
public:
|
||||
|
@ -1231,74 +1180,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenCopyOp : public OpConversionPattern<AtenCopyOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
Value src = adaptor.getSrc();
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
|
||||
// The non_blocking should be a constant `False`.
|
||||
bool nonBlocking;
|
||||
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking must be a constant");
|
||||
} else if (nonBlocking) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking is expected to be false");
|
||||
}
|
||||
|
||||
// The size of the src tensor can be different from the self but should be
|
||||
// broadcastable. Therefore, broadcasting the src tensor to match the size
|
||||
// of the self tensor.
|
||||
SmallVector<Value> selfSizes = getTensorSizes(rewriter, loc, self);
|
||||
for (unsigned i = 0; i < selfSizes.size(); i++)
|
||||
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
|
||||
Value broadcastedSrc;
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, src, selfSizes, broadcastedSrc))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(),
|
||||
rewriter.getContext());
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
selfType.getRank(), utils::IteratorType::parallel);
|
||||
Value result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc,
|
||||
/*resultType=*/selfType,
|
||||
/*inputs=*/broadcastedSrc,
|
||||
/*outputs=*/self,
|
||||
/*indexingMaps=*/llvm::ArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = args[0];
|
||||
if (args[0].getType() != args[1].getType()) {
|
||||
result = convertScalarToDtype(b, loc, args[0],
|
||||
args[1].getType());
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
->getResult(0);
|
||||
|
||||
Type resultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSliceScatterOp
|
||||
: public OpConversionPattern<AtenSliceScatterOp> {
|
||||
|
@ -1449,12 +1330,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCatOp>();
|
||||
patterns.add<ConvertAtenCatOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBroadcastToOp>();
|
||||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCopyOp>();
|
||||
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSliceScatterOp>();
|
||||
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenViewAsComplexOp>();
|
||||
|
|
|
@ -159,274 +159,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
|
||||
// Get the rank of both matrix.
|
||||
unsigned lhsRank = lhsType.getRank();
|
||||
unsigned rhsRank = rhsType.getRank();
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
auto resultType = newResultType.cast<RankedTensorType>();
|
||||
Type elementType = resultType.getElementType();
|
||||
|
||||
// The different cases of torch_matmul op is mentioned here:
|
||||
// https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
|
||||
// First Case: Dot Product.
|
||||
if (lhsRank == 1 && rhsRank == 1) {
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
||||
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
||||
|
||||
Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
|
||||
Value dotProd =
|
||||
rewriter
|
||||
.create<linalg::DotOp>(loc, zeroTensor.getType(),
|
||||
ValueRange{lhs, rhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Second Case: Vec-Mat Multiplication.
|
||||
if (lhsRank == 1 && rhsRank == 2) {
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
||||
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
||||
|
||||
Value zeroTensor =
|
||||
createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
|
||||
Value matmul =
|
||||
rewriter
|
||||
.create<linalg::VecmatOp>(loc, zeroTensor.getType(),
|
||||
ValueRange{lhs, rhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Third Case: Matrix-Vec Multiplication.
|
||||
if (lhsRank == 2 && rhsRank == 1) {
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
||||
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
||||
|
||||
Value zeroTensor =
|
||||
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
|
||||
Value matmul =
|
||||
rewriter
|
||||
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
|
||||
ValueRange{lhs, rhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Fourth Case: Batch-Matrix Multiplication.
|
||||
// TODO: Handle batch matrix multiplication when one of the matrix is unity
|
||||
// rank and the other has batch dimension.
|
||||
if (lhsRank > 1 && rhsRank > 1) {
|
||||
unsigned maxRank = std::max(lhsRank, rhsRank);
|
||||
unsigned minRank = std::min(lhsRank, rhsRank);
|
||||
unsigned batchRank = maxRank - 2;
|
||||
|
||||
// At least one of the matrix must have rank greater than 2.
|
||||
if (batchRank <= 0) {
|
||||
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
|
||||
}
|
||||
|
||||
// The `broadcastedBatchShape` contains batch dimensions of the resultant
|
||||
// matrix.
|
||||
SmallVector<Value> broadcastedBatchShape(batchRank);
|
||||
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
|
||||
Value maxDim;
|
||||
// Compute broadcasted batch dimensions if the batch dimensions of
|
||||
// the matrices are broadcastable.
|
||||
for (unsigned i = 1; i <= batchRank; i++) {
|
||||
if (i <= minRank - 2) {
|
||||
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
|
||||
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
|
||||
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
|
||||
} else {
|
||||
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
|
||||
}
|
||||
broadcastedBatchShape[batchRank - i] = maxDim;
|
||||
}
|
||||
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
|
||||
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
|
||||
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
||||
|
||||
// Compute broadcasted shape of both the matrices in integer format.
|
||||
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
|
||||
lhsBroadcastToShape.push_back(lhsDim0);
|
||||
lhsBroadcastToShape.push_back(lhsDim1);
|
||||
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
|
||||
rhsBroadcastToShape.push_back(rhsDim0);
|
||||
rhsBroadcastToShape.push_back(rhsDim1);
|
||||
for (unsigned i = 0; i < maxRank; i++) {
|
||||
lhsBroadcastToShape[i] =
|
||||
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
|
||||
rhsBroadcastToShape[i] =
|
||||
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
|
||||
}
|
||||
|
||||
// Broadcast the batch dimensions of both the matrices.
|
||||
Value broadcastedLhs, broadcastedRhs;
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
if (maxRank == 3) {
|
||||
Value zeroTensor = createZeroInitTensor(
|
||||
rewriter, loc,
|
||||
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
|
||||
elementType);
|
||||
Value matmul =
|
||||
rewriter
|
||||
.create<linalg::BatchMatmulOp>(
|
||||
loc, zeroTensor.getType(),
|
||||
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check if the result of the matrix multiplication has more than one
|
||||
// dynamic batch dimensions.
|
||||
SmallVector<int64_t> batchDimsInt =
|
||||
makeShapeTorchCompatible(resultType.getShape());
|
||||
batchDimsInt.pop_back();
|
||||
batchDimsInt.pop_back();
|
||||
bool multipleDynamicBatchDims =
|
||||
llvm::count(batchDimsInt, kUnknownSize) > 1;
|
||||
|
||||
// TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
|
||||
// at most one dynamic batch dimension due to limited support of the
|
||||
// `tensor.ExpandShape` op.
|
||||
if (!multipleDynamicBatchDims) {
|
||||
// Collapse the batch dimensions into one dimension. The resultant rank
|
||||
// will always be 3.
|
||||
SmallVector<ReassociationIndices> reassociation(3);
|
||||
for (unsigned i = 0, j = 0; i < maxRank; i++) {
|
||||
if (i >= batchRank)
|
||||
j++;
|
||||
reassociation[j].push_back(i);
|
||||
}
|
||||
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
|
||||
op->getLoc(), broadcastedLhs, reassociation);
|
||||
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
|
||||
op->getLoc(), broadcastedRhs, reassociation);
|
||||
|
||||
// Compute the result shape after collapsing the batch dimensions.
|
||||
SmallVector<Value> collapsedResultShape;
|
||||
collapsedResultShape.push_back(broadcastedBatchShape[0]);
|
||||
for (unsigned i = 1; i < batchRank; i++) {
|
||||
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
|
||||
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
|
||||
}
|
||||
collapsedResultShape.push_back(lhsDim0);
|
||||
collapsedResultShape.push_back(rhsDim1);
|
||||
SmallVector<OpFoldResult> updatedCollapseResultShape =
|
||||
getAsOpFoldResult(collapsedResultShape);
|
||||
|
||||
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, updatedCollapseResultShape, elementType);
|
||||
Value c0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(elementType));
|
||||
Value zeroTensor =
|
||||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
|
||||
Value batchMatMul =
|
||||
rewriter
|
||||
.create<linalg::BatchMatmulOp>(
|
||||
loc, zeroTensor.getType(),
|
||||
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, resultType, batchMatMul, reassociation);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
|
||||
expandResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<AffineExpr> lhsExpr;
|
||||
SmallVector<AffineExpr> rhsExpr;
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
batchRank, utils::IteratorType::parallel);
|
||||
for (unsigned i = 0; i < batchRank; i++) {
|
||||
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
outExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
}
|
||||
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
||||
rewriter.getAffineDimExpr(batchRank + 1)});
|
||||
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
|
||||
rewriter.getAffineDimExpr(batchRank + 2)});
|
||||
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
||||
rewriter.getAffineDimExpr(batchRank + 2)});
|
||||
|
||||
SmallVector<Value> resultShape(broadcastedBatchShape);
|
||||
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
|
||||
Value zeroTensor =
|
||||
createZeroInitTensor(rewriter, loc, resultShape, elementType);
|
||||
auto indexingMaps =
|
||||
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
|
||||
iteratorTypes.insert(iteratorTypes.end(),
|
||||
{utils::IteratorType::parallel,
|
||||
utils::IteratorType::reduction,
|
||||
utils::IteratorType::parallel});
|
||||
|
||||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, zeroTensor.getType(),
|
||||
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value l = args[0], r = args[1], res = args[2];
|
||||
Value mul = b.create<arith::MulFOp>(loc, l, r);
|
||||
Value add = b.create<arith::AddFOp>(loc, mul, res);
|
||||
b.create<linalg::YieldOp>(loc, add);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
|
||||
public:
|
||||
|
@ -858,8 +590,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
|
|||
patterns.add<ConvertAtenMmOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenFlipOp>();
|
||||
patterns.add<ConvertAtenFlipOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMatmulOp>();
|
||||
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBmmOp>();
|
||||
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenConvolutionOp>();
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
|
@ -47,6 +48,7 @@ public:
|
|||
registry.insert<arith::ArithDialect>();
|
||||
registry.insert<cf::ControlFlowDialect>();
|
||||
registry.insert<complex::ComplexDialect>();
|
||||
registry.insert<TMTensor::TMTensorDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
|
|
|
@ -320,113 +320,6 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
|||
.getResult(0);
|
||||
}
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
LogicalResult torch_to_linalg::broadcastToGivenShape(
|
||||
Operation *op, PatternRewriter &rewriter, Value input,
|
||||
SmallVector<Value> broadcastToShape, Value &result,
|
||||
SmallVector<bool> useBroadcastToShape) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> inputShape =
|
||||
makeShapeTorchCompatible(inputType.getShape());
|
||||
if (broadcastToShape.size() < inputShape.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
||||
"size of the input shape");
|
||||
}
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
SmallVector<Value> outShape;
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value zeroIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
Value oneIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
size_t diff = broadcastToShape.size() - inputShape.size();
|
||||
for (size_t i = 0; i < broadcastToShape.size(); i++) {
|
||||
Value shapeValue = broadcastToShape[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (inputShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
} else {
|
||||
// Case of dynamic input dimension wherein the shape to broadcast will
|
||||
// yield us the dimension size of the output.
|
||||
Value dim = getDimOp(rewriter, loc, input, j);
|
||||
if (!useBroadcastToShape.empty()) {
|
||||
if (useBroadcastToShape[i])
|
||||
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
|
||||
}
|
||||
outShape.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outShape), elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
|
||||
utils::IteratorType::parallel);
|
||||
result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), ValueRange(), outTensor,
|
||||
indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
// `loopIndices` contains IV of the linalg loops which
|
||||
// would be used to extract values from the input tensor
|
||||
// later on.
|
||||
SmallVector<Value> loopIndices;
|
||||
for (size_t i = 0; i < broadcastToShape.size(); ++i) {
|
||||
if (i < diff)
|
||||
continue;
|
||||
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||
}
|
||||
// `inputIndicesToExtract` contains i-th linalg loop IV if
|
||||
// the i-th input dimension is not 1, else it contains a
|
||||
// zero index.
|
||||
SmallVector<Value> inputIndicesToExtract;
|
||||
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
|
||||
if (inputShape[i] == 1) {
|
||||
inputIndicesToExtract.push_back(zeroIndex);
|
||||
} else {
|
||||
Value inputDim = getDimOp(b, loc, input, i);
|
||||
Value isEqual = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isEqual, zeroIndex, loopIndices[i]);
|
||||
inputIndicesToExtract.push_back(select);
|
||||
}
|
||||
}
|
||||
// Extract and yield the value from input tensor at
|
||||
// `inputIndicesToExtract` indices.
|
||||
Value result = b.create<tensor::ExtractOp>(
|
||||
loc, input, inputIndicesToExtract);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
|
||||
Value tensor) {
|
||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -72,12 +71,6 @@ Value createElementwiseLinalgGeneric(
|
|||
Type resultElementType,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
LogicalResult
|
||||
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
|
||||
SmallVector<Value> broadcastToShape, Value &result,
|
||||
SmallVector<bool> useBroadcastToShape = {});
|
||||
|
||||
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
|
||||
// <?x?xf32>
|
||||
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
|
||||
|
|
|
@ -80,6 +80,69 @@ static TypedAttr getNumericLimit(PatternRewriter &rewriter, Type elementType,
|
|||
}
|
||||
}
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
static LogicalResult
|
||||
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
|
||||
SmallVector<Value> broadcastToShape, Value &result,
|
||||
SmallVector<bool> useBroadcastToShape = {}) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> inputShape =
|
||||
makeShapeTorchCompatible(inputType.getShape());
|
||||
if (broadcastToShape.size() < inputShape.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
||||
"size of the input shape");
|
||||
}
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
SmallVector<Value> outShape;
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value oneIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
size_t diff = broadcastToShape.size() - inputShape.size();
|
||||
for (size_t i = 0; i < broadcastToShape.size(); i++) {
|
||||
Value shapeValue = broadcastToShape[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (inputShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value isNegative = rewriter.createOrFold<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.createOrFold<arith::SelectOp>(
|
||||
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
} else {
|
||||
// Case of dynamic input dimension wherein the shape to broadcast will
|
||||
// yield us the dimension size of the output.
|
||||
Value dim = getDimOp(rewriter, loc, input, j);
|
||||
if (!useBroadcastToShape.empty()) {
|
||||
if (useBroadcastToShape[i])
|
||||
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
|
||||
}
|
||||
outShape.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outShape), elementType);
|
||||
|
||||
result = rewriter
|
||||
.create<TMTensor::NumpyBroadcastOp>(
|
||||
op->getLoc(), outTensor.getType(), SmallVector<Value>{input},
|
||||
SmallVector<Value>{outTensor})
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// This function will reformat the `index` and `src` from torch operations
|
||||
// like `torch.scatter` or `torch.scatter_reduce` to match the expected
|
||||
// input for the TMScatterOp. It will return the reformated `index` and `src`
|
||||
|
@ -1626,6 +1689,392 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts `AtenBroadcastToOp` by broadcasting its `self` operand to the
// sizes given in its `size` list (via `broadcastToGivenShape`) and casting the
// result to the converted result type.
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Fails to match when the operand/result types are not linalg compatible,
  // when the size list does not come from a list construct, or when the
  // broadcast itself cannot be built.
  LogicalResult
  matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Value self = adaptor.getSelf();

    SmallVector<Value> inShape;
    if (!getListConstructElements(adaptor.getSize(), inShape)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the size list is not from list construct");
    }

    // For dynamic input dimension we need to use the `broadcastToShape`
    // which in this case is `inShapeConverted` because this shape will yield
    // us the dimension size of the output.
    //
    // A size entry is flagged only when it is not a constant integer and is
    // produced by a size query op.
    SmallVector<bool> useBroadcastToShape;
    for (Value size : inShape) {
      int64_t constantSize;
      bool isConstant = matchPattern(size, m_TorchConstantInt(&constantSize));
      // `getDefiningOp()` returns null for block arguments, and plain `isa`
      // asserts on a null pointer, so the null-tolerant form is required.
      bool isSizeQuery = !isConstant &&
                         llvm::isa_and_nonnull<AtenSizeOp, AtenSizeIntOp>(
                             size.getDefiningOp());
      useBroadcastToShape.push_back(isSizeQuery);
    }

    SmallVector<Value> inShapeConverted = getTypeConvertedValues(
        rewriter, op.getLoc(), getTypeConverter(), inShape);
    Value result;
    if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
                                     result, useBroadcastToShape))) {
      return rewriter.notifyMatchFailure(
          op, "unable to perform broadcast operation");
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
    return success();
  }
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts `AtenCopyOp`: `src` is broadcast to the sizes of `self`, then an
// elementwise `linalg.generic` writes the (dtype-converted, if necessary)
// source elements into `self`.
class ConvertAtenCopyOp : public OpConversionPattern<AtenCopyOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Only a constant `false` value of `non_blocking` is supported; anything
  // else is reported as a match failure.
  LogicalResult
  matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    Value dst = adaptor.getSelf();
    Value source = adaptor.getSrc();
    auto dstType = dst.getType().cast<RankedTensorType>();

    // The non_blocking should be a constant `False`.
    bool isNonBlocking;
    if (!matchPattern(op.getNonBlocking(),
                      m_TorchConstantBool(&isNonBlocking))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non_blocking must be a constant");
    }
    if (isNonBlocking) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non_blocking is expected to be false");
    }

    // `src` may have a different (but broadcastable) size than `self`, so it
    // is first broadcast to the sizes of `self`, expressed as i64 values.
    SmallVector<Value> dstSizes = getTensorSizes(rewriter, loc, dst);
    for (Value &size : dstSizes)
      size = castIndexToInt64(rewriter, loc, size);

    Value broadcastSource;
    if (failed(broadcastToGivenShape(op, rewriter, source, dstSizes,
                                     broadcastSource))) {
      return rewriter.notifyMatchFailure(
          op, "unable to perform broadcast operation");
    }

    // Input and output are both accessed with the identity map and every loop
    // is parallel.
    const int64_t rank = dstType.getRank();
    AffineMap identity =
        AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext());
    SmallVector<AffineMap> indexingMaps(2, identity);
    SmallVector<utils::IteratorType> iteratorTypes(
        rank, utils::IteratorType::parallel);

    auto copyBody = [](OpBuilder &b, Location bodyLoc, ValueRange args) {
      // args[0] is the source element, args[1] the destination element; the
      // source is converted only when the element types differ.
      Value srcElem = args[0];
      Type dstElemType = args[1].getType();
      Value yielded =
          srcElem.getType() == dstElemType
              ? srcElem
              : convertScalarToDtype(b, bodyLoc, srcElem, dstElemType);
      b.create<linalg::YieldOp>(bodyLoc, yielded);
    };

    auto copyOp = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultType=*/dstType,
        /*inputs=*/broadcastSource,
        /*outputs=*/dst,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes, copyBody);
    Value copied = copyOp->getResult(0);

    Type convertedType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, convertedType, copied);
    return success();
  }
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts `AtenMatmulOp` according to the operand ranks:
//   rank 1 x rank 1   -> linalg.dot
//   rank 1 x rank 2   -> linalg.vecmat
//   rank 2 x rank 1   -> linalg.matvec
//   both ranks > 1    -> batch dimensions are broadcast, then the product is
//                        lowered to linalg.batch_matmul (directly for rank 3,
//                        or after collapsing the batch dimensions when the
//                        result has at most one dynamic batch dimension), or
//                        otherwise to a linalg.generic.
// Every other rank combination is reported as a match failure.
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = adaptor.getSelf();
    Value rhs = adaptor.getOther();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    auto lhsType = lhs.getType().cast<RankedTensorType>();
    auto rhsType = rhs.getType().cast<RankedTensorType>();

    // Get the rank of both matrix.
    unsigned lhsRank = lhsType.getRank();
    unsigned rhsRank = rhsType.getRank();

    Type newResultType = getTypeConverter()->convertType(op.getType());
    auto resultType = newResultType.cast<RankedTensorType>();
    Type elementType = resultType.getElementType();

    // The different cases of torch_matmul op is mentioned here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html

    // First Case: Dot Product.
    if (lhsRank == 1 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);

      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
      Value dotProd =
          rewriter
              .create<linalg::DotOp>(loc, zeroTensor.getType(),
                                     ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
      return success();
    }

    // Second Case: Vec-Mat Multiplication.
    if (lhsRank == 1 && rhsRank == 2) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
      Value matmul =
          rewriter
              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Third Case: Matrix-Vec Multiplication.
    if (lhsRank == 2 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
      Value matmul =
          rewriter
              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Fourth Case: Batch-Matrix Multiplication.
    // TODO: Handle batch matrix multiplication when one of the matrix is unity
    // rank and the other has batch dimension.
    if (lhsRank > 1 && rhsRank > 1) {
      unsigned maxRank = std::max(lhsRank, rhsRank);
      unsigned minRank = std::min(lhsRank, rhsRank);
      unsigned batchRank = maxRank - 2;

      // At least one of the matrix must have rank greater than 2.
      if (batchRank == 0) {
        return rewriter.notifyMatchFailure(op, "expected batch dimensions");
      }

      // The `broadcastedBatchShape` contains batch dimensions of the resultant
      // matrix.
      SmallVector<Value> broadcastedBatchShape(batchRank);
      Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
      Value maxDim;
      // Compute broadcasted batch dimensions if the batch dimensions of
      // the matrices are broadcastable. Batch dimensions are walked from the
      // innermost (i == 1) outwards.
      for (unsigned i = 1; i <= batchRank; i++) {
        if (i <= minRank - 2) {
          // Both operands have this batch dimension: take the larger size.
          Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
          Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
          maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
        } else {
          // Only the higher-rank operand has this batch dimension.
          maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
        }
        broadcastedBatchShape[batchRank - i] = maxDim;
      }

      Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      // Compute broadcasted shape of both the matrices in integer format.
      SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
      lhsBroadcastToShape.push_back(lhsDim0);
      lhsBroadcastToShape.push_back(lhsDim1);
      SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
      rhsBroadcastToShape.push_back(rhsDim0);
      rhsBroadcastToShape.push_back(rhsDim1);
      for (unsigned i = 0; i < maxRank; i++) {
        lhsBroadcastToShape[i] =
            castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
        rhsBroadcastToShape[i] =
            castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
      }

      // Broadcast the batch dimensions of both the matrices.
      Value broadcastedLhs, broadcastedRhs;
      if (failed(broadcastToGivenShape(op, rewriter, lhs, lhsBroadcastToShape,
                                       broadcastedLhs))) {
        return rewriter.notifyMatchFailure(
            op, "unable to perform broadcast operation");
      }
      if (failed(broadcastToGivenShape(op, rewriter, rhs, rhsBroadcastToShape,
                                       broadcastedRhs))) {
        return rewriter.notifyMatchFailure(
            op, "unable to perform broadcast operation");
      }

      if (maxRank == 3) {
        Value zeroTensor = createZeroInitTensor(
            rewriter, loc,
            ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
            elementType);
        Value matmul =
            rewriter
                .create<linalg::BatchMatmulOp>(
                    loc, zeroTensor.getType(),
                    ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
                .getResult(0);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
        return success();
      }

      // Check if the result of the matrix multiplication has more than one
      // dynamic batch dimensions.
      SmallVector<int64_t> batchDimsInt =
          makeShapeTorchCompatible(resultType.getShape());
      batchDimsInt.pop_back();
      batchDimsInt.pop_back();
      bool multipleDynamicBatchDims =
          llvm::count(batchDimsInt, kUnknownSize) > 1;

      // TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
      // at most one dynamic batch dimension due to limited support of the
      // `tensor.ExpandShape` op.
      if (!multipleDynamicBatchDims) {
        // Collapse the batch dimensions into one dimension. The resultant rank
        // will always be 3.
        SmallVector<ReassociationIndices> reassociation(3);
        for (unsigned i = 0, j = 0; i < maxRank; i++) {
          if (i >= batchRank)
            j++;
          reassociation[j].push_back(i);
        }
        Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
            op->getLoc(), broadcastedLhs, reassociation);
        Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
            op->getLoc(), broadcastedRhs, reassociation);

        // Compute the result shape after collapsing the batch dimensions.
        SmallVector<Value> collapsedResultShape;
        collapsedResultShape.push_back(broadcastedBatchShape[0]);
        for (unsigned i = 1; i < batchRank; i++) {
          collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
              loc, collapsedResultShape[0], broadcastedBatchShape[i]);
        }
        collapsedResultShape.push_back(lhsDim0);
        collapsedResultShape.push_back(rhsDim1);
        SmallVector<OpFoldResult> updatedCollapseResultShape =
            getAsOpFoldResult(collapsedResultShape);

        Value initTensor = rewriter.create<tensor::EmptyOp>(
            loc, updatedCollapseResultShape, elementType);
        Value c0 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(elementType));
        Value zeroTensor =
            rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);

        Value batchMatMul =
            rewriter
                .create<linalg::BatchMatmulOp>(
                    loc, zeroTensor.getType(),
                    ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
                .getResult(0);
        Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
            loc, resultType, batchMatMul, reassociation);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
                                                    expandResult);
        return success();
      }

      // Fallback: express the batched product as a linalg.generic. The loop
      // nest is (batch dims..., m, k, n), with `k` the reduction dimension.
      SmallVector<AffineExpr> lhsExpr;
      SmallVector<AffineExpr> rhsExpr;
      SmallVector<AffineExpr> outExpr;
      SmallVector<utils::IteratorType> iteratorTypes(
          batchRank, utils::IteratorType::parallel);
      for (unsigned i = 0; i < batchRank; i++) {
        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
        outExpr.push_back(rewriter.getAffineDimExpr(i));
      }
      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 1)});
      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
                                     rewriter.getAffineDimExpr(batchRank + 2)});
      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 2)});

      SmallVector<Value> resultShape(broadcastedBatchShape);
      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, resultShape, elementType);
      auto indexingMaps =
          AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
      iteratorTypes.insert(iteratorTypes.end(),
                           {utils::IteratorType::parallel,
                            utils::IteratorType::reduction,
                            utils::IteratorType::parallel});

      Value finalRes =
          rewriter
              .create<linalg::GenericOp>(
                  loc, zeroTensor.getType(),
                  ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
                  /*indexingMaps=*/indexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
                    Value l = args[0], r = args[1], res = args[2];
                    // The floating-point arith ops do not accept integer
                    // operands, so integer elements need the integer
                    // multiply/add instead.
                    Value mul, add;
                    if (l.getType().isa<mlir::IntegerType>()) {
                      mul = b.create<arith::MulIOp>(loc, l, r);
                      add = b.create<arith::AddIOp>(loc, mul, res);
                    } else {
                      mul = b.create<arith::MulFOp>(loc, l, r);
                      add = b.create<arith::AddFOp>(loc, mul, res);
                    }
                    b.create<linalg::YieldOp>(loc, add);
                  })
              .getResult(0);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
      return success();
    }

    // Remaining rank combinations are not handled by any case above.
    return rewriter.notifyMatchFailure(
        op, "unimplemented: unsupported combination of operand ranks");
  }
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -1673,6 +2122,13 @@ public:
|
|||
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
||||
context);
|
||||
|
||||
target.addIllegalOp<AtenBroadcastToOp>();
|
||||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCopyOp>();
|
||||
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMatmulOp>();
|
||||
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenScatterSrcOp>();
|
||||
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
||||
|
||||
|
|
|
@ -8,4 +8,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
|
|||
MLIRArithDialect
|
||||
MLIRLinalgDialect
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTMTensorDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
|
|
Loading…
Reference in New Issue