Add lowerings for tm_tensor.npbroadcast

This moves the existing broadcasting pattern into an opt-in TMTensor
pass that lowers tm_tensor.npbroadcast ops to a pessimistic linalg
equivalent, maintaining the semantics of a numpy broadcast past the
torch lowering stage.

Additionally, this relocates the lowerings for ops that use numpy
broadcasting semantics (`aten.matmul`, `aten.broadcast_to`, and
`aten.copy`) to the TorchToTMTensor conversion.
numpy_style_broadcast
Quinn Dawkins 2023-08-26 20:03:57 -04:00
parent 65bc15b340
commit a7f506adc4
15 changed files with 599 additions and 517 deletions
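
For illustration, a minimal sketch of the op these lowerings target, written in the assembly format defined by the TableGen diff below (value names and shapes here are hypothetical):

  %init = tensor.empty() : tensor<3x4xf32>
  %0 = tm_tensor.npbroadcast ins(%arg0 : tensor<1x4xf32>)
       outs(%init : tensor<3x4xf32>) -> tensor<3x4xf32>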


@@ -337,16 +337,15 @@ def TMTensor_NumpyBroadcastOp : TMTensor_Op<"npbroadcast",
let description = [{
}];
let arguments = (ins
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
Variadic<AnyRankedTensorOrMemRefType>:$outputs
Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs
);
let results = (outs Variadic<AnyRankedTensorOrMemRefType>:$results);
let regions = (region AnyRegion:$region);
let results = (outs Variadic<AnyShaped>:$results);
let assemblyFormat = [{
attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
(`->` type($results)^)?
}];
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{


@@ -17,6 +17,8 @@ namespace mlir {
namespace torch {
namespace TMTensor {
std::unique_ptr<OperationPass<func::FuncOp>>
createTMTensorBroadcastToLinalgPass();
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
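// Usage sketch (assumed, not part of this diff): adding the new pass to a
// pass pipeline from outside the TMTensor namespace:
//   pm.addNestedPass<func::FuncOp>(
//       TMTensor::createTMTensorBroadcastToLinalgPass());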


@@ -12,6 +12,12 @@
include "mlir/Pass/PassBase.td"
def TMTensorBroadcastToLinalg :
Pass<"tm-tensor-broadcast-to-linalg", "func::FuncOp"> {
let summary = "Convert TMTensor NumpyBroadcastOps to linalg.";
let constructor = "mlir::torch::TMTensor::createTMTensorBroadcastToLinalgPass()";
}
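// The lowering is opt-in; as a usage sketch (invocation assumed, not part
// of this diff), the pass can be exercised in isolation with:
//   torch-mlir-opt --tm-tensor-broadcast-to-linalg input.mlir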
def TMTensorToLoops :
Pass<"tm-tensor-to-loops", "func::FuncOp"> {
let summary = "Convert TMTensor ops to loops and Linalg ops.";


@@ -884,11 +884,6 @@ LogicalResult NumpyBroadcastOp::verify() {
if (getNumOutputs() != 1) {
return emitOpError("expected one output operand");
}
if (getInputRank() != getOutputRank()) {
return emitOpError("expected input and output ranks to be the same");
}
return success();
}
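// With the rank-equality check removed, rank-expanding broadcasts now pass
// verification; a hypothetical example:
//   %0 = tm_tensor.npbroadcast ins(%v : tensor<4xf32>)
//        outs(%init : tensor<3x4xf32>) -> tensor<3x4xf32>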


@@ -1,4 +1,5 @@
add_mlir_library(TorchMLIRTMTensorPasses
ConvertBroadcastToLinalg.cpp
ConvertToLoops.cpp
Bufferize.cpp
Passes.cpp


@@ -0,0 +1,122 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::torch::TMTensor;
/// Pattern to lower a `tm_tensor.npbroadcast` op to an equivalent
/// `linalg.generic` op.
namespace {
class LowerNumpyBroadcastToLinalg : public OpRewritePattern<NumpyBroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(NumpyBroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
Location loc = broadcastOp.getLoc();
Value input = broadcastOp.getInput();
Value output = broadcastOp.getOutput();
auto inputType = input.getType().cast<RankedTensorType>();
auto outputType = output.getType().cast<RankedTensorType>();
int64_t diff = outputType.getRank() - inputType.getRank();
ArrayRef<int64_t> inputShape = inputType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
Value zeroIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value oneIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(outputType.getRank())};
SmallVector<utils::IteratorType> iteratorTypes(
outputType.getRank(), utils::IteratorType::parallel);
rewriter
.replaceOpWithNewOp<linalg::GenericOp>(
broadcastOp, output.getType(), ValueRange(), output, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// `loopIndices` contains the IVs of the linalg loops, which
// are used later to extract values from the input tensor.
SmallVector<Value> loopIndices;
for (int64_t i = diff; i < outputType.getRank(); ++i) {
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
}
// `inputIndicesToExtract` contains the i-th linalg loop IV
// if the i-th input dimension is not 1; otherwise it
// contains a zero index.
SmallVector<Value> inputIndicesToExtract;
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
if (inputShape[i] == 1 && outputShape[i + diff] != 1) {
inputIndicesToExtract.push_back(zeroIndex);
} else {
Value inputDim = b.createOrFold<tensor::DimOp>(loc, input, i);
Value isEqual = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
Value select = b.create<arith::SelectOp>(
loc, isEqual, zeroIndex, loopIndices[i]);
inputIndicesToExtract.push_back(select);
}
}
// Extract and yield the value from input tensor at
// `inputIndicesToExtract` indices.
Value result = b.create<tensor::ExtractOp>(loc, input,
inputIndicesToExtract);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
namespace {
struct TMTensorBroadcastToLinalgPass
: public TMTensorBroadcastToLinalgBase<TMTensorBroadcastToLinalgPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, mlir::arith::ArithDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<LowerNumpyBroadcastToLinalg>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
torch::TMTensor::createTMTensorBroadcastToLinalgPass() {
return std::make_unique<TMTensorBroadcastToLinalgPass>();
}
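
For reference, a hedged sketch of the IR this pattern produces for a static 1x4 to 3x4 broadcast (value names are assumed and the exact attribute spelling may differ):

  #map = affine_map<(d0, d1) -> (d0, d1)>
  %c0 = arith.constant 0 : index
  %0 = linalg.generic {indexing_maps = [#map],
                       iterator_types = ["parallel", "parallel"]}
       outs(%init : tensor<3x4xf32>) {
  ^bb0(%out: f32):
    %i1 = linalg.index 1 : index
    // Input dim 0 is statically 1 while the output dim is not, so it is
    // indexed with the constant zero index.
    %v = tensor.extract %input[%c0, %i1] : tensor<1x4xf32>
    linalg.yield %v : f32
  } -> tensor<3x4xf32>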


@@ -26,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
MLIRLinalgDialect
MLIRMathDialect
TorchMLIRTorchDialect
TorchMLIRTMTensorDialect
)
torch_mlir_target_includes(TorchMLIRTorchToLinalg)


@@ -1162,57 +1162,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value self = adaptor.getSelf();
SmallVector<Value> inShape;
if (!getListConstructElements(adaptor.getSize(), inShape)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the size list is not from list construct");
}
// For a dynamic input dimension we need to use `broadcastToShape`
// (here `inShapeConverted`), since that shape yields the output
// dimension size.
SmallVector<bool> useBroadcastToShape;
for (auto x : inShape) {
int64_t dim;
if (!matchPattern(x, m_TorchConstantInt(&dim))) {
Operation* defOp = x.getDefiningOp();
if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
useBroadcastToShape.push_back(true);
else
useBroadcastToShape.push_back(false);
} else {
useBroadcastToShape.push_back(false);
}
}
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
rewriter, op.getLoc(), getTypeConverter(), inShape);
Value result;
if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
inShapeConverted, result,
useBroadcastToShape))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
public:
@@ -1231,74 +1180,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenCopyOp : public OpConversionPattern<AtenCopyOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.getSelf();
Value src = adaptor.getSrc();
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
// The non_blocking argument must be the constant `False`.
bool nonBlocking;
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking must be a constant");
} else if (nonBlocking) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking is expected to be false");
}
// The shape of the src tensor can differ from that of self but must be
// broadcastable to it. Therefore, broadcast the src tensor to match the
// size of the self tensor.
SmallVector<Value> selfSizes = getTensorSizes(rewriter, loc, self);
for (unsigned i = 0; i < selfSizes.size(); i++)
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
Value broadcastedSrc;
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, src, selfSizes, broadcastedSrc))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(),
rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes(
selfType.getRank(), utils::IteratorType::parallel);
Value result = rewriter
.create<linalg::GenericOp>(
loc,
/*resultType=*/selfType,
/*inputs=*/broadcastedSrc,
/*outputs=*/self,
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
Value result = args[0];
if (args[0].getType() != args[1].getType()) {
result = convertScalarToDtype(b, loc, args[0],
args[1].getType());
}
b.create<linalg::YieldOp>(loc, result);
})
->getResult(0);
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenSliceScatterOp
: public OpConversionPattern<AtenSliceScatterOp> {
@@ -1449,12 +1330,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
target.addIllegalOp<AtenCatOp>();
patterns.add<ConvertAtenCatOp>(typeConverter, context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenContiguousOp>();
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<AtenCopyOp>();
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsComplexOp>();


@@ -159,274 +159,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>();
// Get the ranks of both matrices.
unsigned lhsRank = lhsType.getRank();
unsigned rhsRank = rhsType.getRank();
Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = newResultType.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
// The different cases of the torch.matmul op are described here:
// https://pytorch.org/docs/stable/generated/torch.matmul.html
// First Case: Dot Product.
if (lhsRank == 1 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
Value dotProd =
rewriter
.create<linalg::DotOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
return success();
}
// Second Case: Vec-Mat Multiplication.
if (lhsRank == 1 && rhsRank == 2) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
Value matmul =
rewriter
.create<linalg::VecmatOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Third Case: Matrix-Vec Multiplication.
if (lhsRank == 2 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
Value matmul =
rewriter
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Fourth Case: Batch-Matrix Multiplication.
// TODO: Handle batch matrix multiplication when one of the matrices has
// rank 1 and the other has batch dimensions.
if (lhsRank > 1 && rhsRank > 1) {
unsigned maxRank = std::max(lhsRank, rhsRank);
unsigned minRank = std::min(lhsRank, rhsRank);
unsigned batchRank = maxRank - 2;
// At least one of the matrices must have rank greater than 2.
if (batchRank <= 0) {
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
}
// The `broadcastedBatchShape` contains batch dimensions of the resultant
// matrix.
SmallVector<Value> broadcastedBatchShape(batchRank);
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
Value maxDim;
// Compute broadcasted batch dimensions if the batch dimensions of
// the matrices are broadcastable.
for (unsigned i = 1; i <= batchRank; i++) {
if (i <= minRank - 2) {
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
} else {
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
}
broadcastedBatchShape[batchRank - i] = maxDim;
}
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
// Compute broadcasted shape of both the matrices in integer format.
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
lhsBroadcastToShape.push_back(lhsDim0);
lhsBroadcastToShape.push_back(lhsDim1);
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
rhsBroadcastToShape.push_back(rhsDim0);
rhsBroadcastToShape.push_back(rhsDim1);
for (unsigned i = 0; i < maxRank; i++) {
lhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
rhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
}
// Broadcast the batch dimensions of both the matrices.
Value broadcastedLhs, broadcastedRhs;
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
if (maxRank == 3) {
Value zeroTensor = createZeroInitTensor(
rewriter, loc,
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
elementType);
Value matmul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Check if the result of the matrix multiplication has more than one
// dynamic batch dimension.
SmallVector<int64_t> batchDimsInt =
makeShapeTorchCompatible(resultType.getShape());
batchDimsInt.pop_back();
batchDimsInt.pop_back();
bool multipleDynamicBatchDims =
llvm::count(batchDimsInt, kUnknownSize) > 1;
// TODO: Lowering to `linalg.batch_matmul` is only possible when there is
// at most one dynamic batch dimension due to the limited support of the
// `tensor.expand_shape` op.
if (!multipleDynamicBatchDims) {
// Collapse the batch dimensions into one dimension. The resultant rank
// will always be 3.
SmallVector<ReassociationIndices> reassociation(3);
for (unsigned i = 0, j = 0; i < maxRank; i++) {
if (i >= batchRank)
j++;
reassociation[j].push_back(i);
}
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedLhs, reassociation);
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedRhs, reassociation);
// Compute the result shape after collapsing the batch dimensions.
SmallVector<Value> collapsedResultShape;
collapsedResultShape.push_back(broadcastedBatchShape[0]);
for (unsigned i = 1; i < batchRank; i++) {
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
}
collapsedResultShape.push_back(lhsDim0);
collapsedResultShape.push_back(rhsDim1);
SmallVector<OpFoldResult> updatedCollapseResultShape =
getAsOpFoldResult(collapsedResultShape);
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, updatedCollapseResultShape, elementType);
Value c0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
Value batchMatMul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
.getResult(0);
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
loc, resultType, batchMatMul, reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
expandResult);
return success();
}
SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<utils::IteratorType> iteratorTypes(
batchRank, utils::IteratorType::parallel);
for (unsigned i = 0; i < batchRank; i++) {
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
}
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 1)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
rewriter.getAffineDimExpr(batchRank + 2)});
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 2)});
SmallVector<Value> resultShape(broadcastedBatchShape);
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, elementType);
auto indexingMaps =
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
iteratorTypes.insert(iteratorTypes.end(),
{utils::IteratorType::parallel,
utils::IteratorType::reduction,
utils::IteratorType::parallel});
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], res = args[2];
Value mul = b.create<arith::MulFOp>(loc, l, r);
Value add = b.create<arith::AddFOp>(loc, mul, res);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
return success();
}
return failure();
}
};
} // namespace
namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
@@ -858,8 +590,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenMmOp>(typeConverter, context);
target.addIllegalOp<AtenFlipOp>();
patterns.add<ConvertAtenFlipOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();


@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
@@ -47,6 +48,7 @@ public:
registry.insert<arith::ArithDialect>();
registry.insert<cf::ControlFlowDialect>();
registry.insert<complex::ComplexDialect>();
registry.insert<TMTensor::TMTensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}


@@ -320,113 +320,6 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
.getResult(0);
}
// Broadcasts input tensor based on the broadcastToShape.
LogicalResult torch_to_linalg::broadcastToGivenShape(
Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, Value &result,
SmallVector<bool> useBroadcastToShape) {
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
SmallVector<int64_t> inputShape =
makeShapeTorchCompatible(inputType.getShape());
if (broadcastToShape.size() < inputShape.size()) {
return rewriter.notifyMatchFailure(
op, "invalid shape: broadcastToShape size must not be smaller than the "
"size of the input shape");
}
Type elementType = inputType.getElementType();
Location loc = op->getLoc();
SmallVector<Value> outShape;
// Create affine map and shapes for tensor initialization.
SmallVector<AffineExpr> outExpr;
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
Value zeroIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value oneIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
size_t diff = broadcastToShape.size() - inputShape.size();
for (size_t i = 0; i < broadcastToShape.size(); i++) {
Value shapeValue = broadcastToShape[i];
size_t j = i - diff;
if (i < diff) {
Value isValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shapeValue, zero);
rewriter.create<cf::AssertOp>(
loc, isValid,
rewriter.getStringAttr(
"negative values not allowed in new dimensions"));
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
continue;
}
if (inputShape[j] == 1) {
// Broadcast singleton dimension
Value isNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value select = rewriter.create<arith::SelectOp>(
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
outShape.push_back(select);
} else {
// For a dynamic input dimension, the shape to broadcast to yields the
// output dimension size.
Value dim = getDimOp(rewriter, loc, input, j);
if (!useBroadcastToShape.empty()) {
if (useBroadcastToShape[i])
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
}
outShape.push_back(dim);
}
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outShape), elementType);
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
utils::IteratorType::parallel);
result = rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange(), outTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// `loopIndices` contains the IVs of the linalg loops, which
// are used later to extract values from the input tensor.
SmallVector<Value> loopIndices;
for (size_t i = 0; i < broadcastToShape.size(); ++i) {
if (i < diff)
continue;
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
}
// `inputIndicesToExtract` contains the i-th linalg loop IV
// if the i-th input dimension is not 1; otherwise it
// contains a zero index.
SmallVector<Value> inputIndicesToExtract;
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
if (inputShape[i] == 1) {
inputIndicesToExtract.push_back(zeroIndex);
} else {
Value inputDim = getDimOp(b, loc, input, i);
Value isEqual = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
Value select = rewriter.create<arith::SelectOp>(
loc, isEqual, zeroIndex, loopIndices[i]);
inputIndicesToExtract.push_back(select);
}
}
// Extract and yield the value from input tensor at
// `inputIndicesToExtract` indices.
Value result = b.create<tensor::ExtractOp>(
loc, input, inputIndicesToExtract);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
return success();
}
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();


@@ -1,4 +1,3 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -72,12 +71,6 @@ Value createElementwiseLinalgGeneric(
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
// Broadcasts input tensor based on the broadcastToShape.
LogicalResult
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, Value &result,
SmallVector<bool> useBroadcastToShape = {});
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);


@@ -80,6 +80,69 @@ static TypedAttr getNumericLimit(PatternRewriter &rewriter, Type elementType,
}
}
// Broadcasts input tensor based on the broadcastToShape.
static LogicalResult
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, Value &result,
SmallVector<bool> useBroadcastToShape = {}) {
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
SmallVector<int64_t> inputShape =
makeShapeTorchCompatible(inputType.getShape());
if (broadcastToShape.size() < inputShape.size()) {
return rewriter.notifyMatchFailure(
op, "invalid shape: broadcastToShape size must not be smaller than the "
"size of the input shape");
}
Type elementType = inputType.getElementType();
Location loc = op->getLoc();
SmallVector<Value> outShape;
// Create affine map and shapes for tensor initialization.
SmallVector<AffineExpr> outExpr;
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
Value oneIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
size_t diff = broadcastToShape.size() - inputShape.size();
for (size_t i = 0; i < broadcastToShape.size(); i++) {
Value shapeValue = broadcastToShape[i];
size_t j = i - diff;
if (i < diff) {
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
continue;
}
if (inputShape[j] == 1) {
// Broadcast singleton dimension
Value isNegative = rewriter.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value select = rewriter.createOrFold<arith::SelectOp>(
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
outShape.push_back(select);
} else {
// For a dynamic input dimension, the shape to broadcast to yields the
// output dimension size.
Value dim = getDimOp(rewriter, loc, input, j);
if (!useBroadcastToShape.empty()) {
if (useBroadcastToShape[i])
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
}
outShape.push_back(dim);
}
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outShape), elementType);
result = rewriter
.create<TMTensor::NumpyBroadcastOp>(
op->getLoc(), outTensor.getType(), SmallVector<Value>{input},
SmallVector<Value>{outTensor})
.getResult(0);
return success();
}
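// For illustration (hypothetical IR; value names assumed), the helper above
// now materializes an init tensor plus a single npbroadcast op, e.g.:
//   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
//   %0 = tm_tensor.npbroadcast ins(%src : tensor<1x?xf32>)
//        outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
// rather than emitting the broadcasting linalg.generic inline as before.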
// This function will reformat the `index` and `src` from torch operations
// like `torch.scatter` or `torch.scatter_reduce` to match the expected
// input for the TMScatterOp. It will return the reformatted `index` and `src`
@@ -1626,6 +1689,392 @@ public:
};
} // namespace
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value self = adaptor.getSelf();
SmallVector<Value> inShape;
if (!getListConstructElements(adaptor.getSize(), inShape)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the size list is not from list construct");
}
// For a dynamic input dimension we need to use `broadcastToShape`
// (here `inShapeConverted`), since that shape yields the output
// dimension size.
SmallVector<bool> useBroadcastToShape;
for (auto x : inShape) {
int64_t dim;
if (!matchPattern(x, m_TorchConstantInt(&dim))) {
Operation *defOp = x.getDefiningOp();
if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
useBroadcastToShape.push_back(true);
else
useBroadcastToShape.push_back(false);
} else {
useBroadcastToShape.push_back(false);
}
}
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
rewriter, op.getLoc(), getTypeConverter(), inShape);
Value result;
if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
result, useBroadcastToShape))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenCopyOp : public OpConversionPattern<AtenCopyOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.getSelf();
Value src = adaptor.getSrc();
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
// The non_blocking argument must be the constant `False`.
bool nonBlocking;
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking must be a constant");
} else if (nonBlocking) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking is expected to be false");
}
// The shape of the src tensor can differ from that of self but must be
// broadcastable to it. Therefore, broadcast the src tensor to match the
// size of the self tensor.
SmallVector<Value> selfSizes = getTensorSizes(rewriter, loc, self);
for (unsigned i = 0; i < selfSizes.size(); i++)
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
Value broadcastedSrc;
if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
broadcastedSrc))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(),
rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes(
selfType.getRank(), utils::IteratorType::parallel);
Value result = rewriter
.create<linalg::GenericOp>(
loc,
/*resultType=*/selfType,
/*inputs=*/broadcastedSrc,
/*outputs=*/self,
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
Value result = args[0];
if (args[0].getType() != args[1].getType()) {
result = convertScalarToDtype(b, loc, args[0],
args[1].getType());
}
b.create<linalg::YieldOp>(loc, result);
})
->getResult(0);
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>();
// Get the ranks of both matrices.
unsigned lhsRank = lhsType.getRank();
unsigned rhsRank = rhsType.getRank();
Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = newResultType.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
// The different cases of the torch.matmul op are described here:
// https://pytorch.org/docs/stable/generated/torch.matmul.html
// First Case: Dot Product.
if (lhsRank == 1 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
Value dotProd =
rewriter
.create<linalg::DotOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
return success();
}
// Second Case: Vec-Mat Multiplication.
if (lhsRank == 1 && rhsRank == 2) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
Value matmul =
rewriter
.create<linalg::VecmatOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Third Case: Matrix-Vec Multiplication.
if (lhsRank == 2 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
Value matmul =
rewriter
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Fourth Case: Batch-Matrix Multiplication.
// TODO: Handle batch matrix multiplication when one of the matrices has
// rank 1 and the other has batch dimensions.
if (lhsRank > 1 && rhsRank > 1) {
unsigned maxRank = std::max(lhsRank, rhsRank);
unsigned minRank = std::min(lhsRank, rhsRank);
unsigned batchRank = maxRank - 2;
// At least one of the matrices must have rank greater than 2.
if (batchRank <= 0) {
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
}
// The `broadcastedBatchShape` contains batch dimensions of the resultant
// matrix.
SmallVector<Value> broadcastedBatchShape(batchRank);
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
Value maxDim;
// Compute broadcasted batch dimensions if the batch dimensions of
// the matrices are broadcastable.
for (unsigned i = 1; i <= batchRank; i++) {
if (i <= minRank - 2) {
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
} else {
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
}
broadcastedBatchShape[batchRank - i] = maxDim;
}
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
// Compute broadcasted shape of both the matrices in integer format.
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
lhsBroadcastToShape.push_back(lhsDim0);
lhsBroadcastToShape.push_back(lhsDim1);
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
rhsBroadcastToShape.push_back(rhsDim0);
rhsBroadcastToShape.push_back(rhsDim1);
for (unsigned i = 0; i < maxRank; i++) {
lhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
rhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
}
// Broadcast the batch dimensions of both the matrices.
Value broadcastedLhs, broadcastedRhs;
if (failed(broadcastToGivenShape(op, rewriter, lhs, lhsBroadcastToShape,
broadcastedLhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
if (failed(broadcastToGivenShape(op, rewriter, rhs, rhsBroadcastToShape,
broadcastedRhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
if (maxRank == 3) {
Value zeroTensor = createZeroInitTensor(
rewriter, loc,
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
elementType);
Value matmul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Check if the result of the matrix multiplication has more than one
// dynamic batch dimension.
SmallVector<int64_t> batchDimsInt =
makeShapeTorchCompatible(resultType.getShape());
batchDimsInt.pop_back();
batchDimsInt.pop_back();
bool multipleDynamicBatchDims =
llvm::count(batchDimsInt, kUnknownSize) > 1;
// TODO: Lowering to `linalg.batch_matmul` is only possible when there is
// at most one dynamic batch dimension due to the limited support of the
// `tensor.expand_shape` op.
if (!multipleDynamicBatchDims) {
// Collapse the batch dimensions into one dimension. The resultant rank
// will always be 3.
SmallVector<ReassociationIndices> reassociation(3);
for (unsigned i = 0, j = 0; i < maxRank; i++) {
if (i >= batchRank)
j++;
reassociation[j].push_back(i);
}
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedLhs, reassociation);
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedRhs, reassociation);
// Compute the result shape after collapsing the batch dimensions.
SmallVector<Value> collapsedResultShape;
collapsedResultShape.push_back(broadcastedBatchShape[0]);
for (unsigned i = 1; i < batchRank; i++) {
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
}
collapsedResultShape.push_back(lhsDim0);
collapsedResultShape.push_back(rhsDim1);
SmallVector<OpFoldResult> updatedCollapseResultShape =
getAsOpFoldResult(collapsedResultShape);
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, updatedCollapseResultShape, elementType);
Value c0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
Value batchMatMul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
.getResult(0);
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
loc, resultType, batchMatMul, reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
expandResult);
return success();
}
SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<utils::IteratorType> iteratorTypes(
batchRank, utils::IteratorType::parallel);
for (unsigned i = 0; i < batchRank; i++) {
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
}
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 1)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
rewriter.getAffineDimExpr(batchRank + 2)});
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 2)});
SmallVector<Value> resultShape(broadcastedBatchShape);
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, elementType);
auto indexingMaps =
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
iteratorTypes.insert(iteratorTypes.end(),
{utils::IteratorType::parallel,
utils::IteratorType::reduction,
utils::IteratorType::parallel});
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], res = args[2];
Value mul = b.create<arith::MulFOp>(loc, l, r);
Value add = b.create<arith::AddFOp>(loc, mul, res);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
return success();
}
return failure();
}
};
} // namespace
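// Illustrative worked example for the batch-broadcast path above (shapes
// hypothetical): multiplying a 2x1x3x4 lhs with a 5x4x6 rhs broadcasts the
// batch dims (2, 1) and (5) to (2, 5), so the operands are broadcast to
// 2x5x3x4 and 2x5x4x6 and the result has shape 2x5x3x6.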
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -1673,6 +2122,13 @@ public:
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenCopyOp>();
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenScatterSrcOp>();
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);


@@ -8,4 +8,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
MLIRArithDialect
MLIRLinalgDialect
TorchMLIRTorchDialect
TorchMLIRTMTensorDialect
)
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)


@@ -14,6 +14,8 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"