[LINALG] Lower `aten.matmul` to `linalg.BatchMatmul`

This commit lowers `aten.matmul` to `linalg.BatchMatmul` under the
following conditions:
1. The result of matrix multiplication must have batch dimensions,
   i.e., rank greater than 2.
2. The resultant matrix must have at most 1 dynamic batch dimension.

It also broadcasts the batch dimensions of the two operands when those
dimensions are broadcast-compatible.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/977/merge snapshot-20220625.514
Gaurav Shukla 2022-06-16 21:15:10 +05:30
parent f774e63abd
commit 1be604bfd3
7 changed files with 307 additions and 117 deletions

View File

@ -61,6 +61,7 @@ TOSA_PASS_SET = {
"ElementwisePowModule_basic",
"BmmModule_basic",
"MmDagModule_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_3d",
"RsubFloatModule_basic",

View File

@ -890,90 +890,6 @@ public:
};
} // namespace
// Emits IR that broadcasts `input` to the shape described by
// `broadcastToShape` and stores the broadcast tensor in `result`.
//
// `broadcastToShape` holds one integer size value per result dimension and may
// have more entries than `input` has dimensions; the extra entries describe
// new leading dimensions. Each result dimension is handled as follows:
//   * new leading dimension: a runtime assert requires the requested size to
//     be non-negative, and that size becomes the extent;
//   * input dimension statically equal to 1: the extent is the requested size,
//     or 1 when the requested size is negative;
//   * any other input dimension: the input extent is kept, and a runtime
//     assert requires the requested size to be negative or equal to it.
//
// Returns failure (through notifyMatchFailure) when `broadcastToShape` has
// fewer entries than `input` has dimensions; otherwise returns success.
static LogicalResult broadcastToGivenShape(Operation *op,
                                           ConversionPatternRewriter &rewriter,
                                           Value input,
                                           SmallVector<Value> broadcastToShape,
                                           Value &result) {
  auto sourceType = input.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  size_t targetRank = broadcastToShape.size();
  if (targetRank < sourceShape.size()) {
    return rewriter.notifyMatchFailure(
        op, "invalid shape: broadcastToShape size must not be smaller than the "
            "size of the input shape");
  }

  Location loc = op->getLoc();
  MLIRContext *context = op->getContext();
  Type elementType = sourceType.getElementType();

  // Extents of the result tensor, and the result expressions of the affine map
  // used to read `input` (one expression per input dimension).
  SmallVector<Value> resultExtents;
  SmallVector<AffineExpr> sourceExprs;

  Value i64Zero =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));

  // Number of result dimensions that have no counterpart in `input`.
  size_t leadingDims = targetRank - sourceShape.size();
  for (size_t dstIdx = 0; dstIdx < targetRank; ++dstIdx) {
    Value requestedSize = broadcastToShape[dstIdx];

    // Newly introduced leading dimension.
    if (dstIdx < leadingDims) {
      Value sizeIsNonNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, requestedSize, i64Zero);
      rewriter.create<cf::AssertOp>(
          loc, sizeIsNonNegative,
          rewriter.getStringAttr(
              "negative values not allowed in new dimensions"));
      resultExtents.push_back(castIntToIndex(rewriter, loc, requestedSize));
      continue;
    }

    size_t srcIdx = dstIdx - leadingDims;

    // Static singleton dimension: every result index reads input index 0.
    if (sourceShape[srcIdx] == 1) {
      Value indexOne =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
      Value sizeIsNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, requestedSize, i64Zero);
      Value requestedExtent = castIntToIndex(rewriter, loc, requestedSize);
      Value extent = rewriter.create<arith::SelectOp>(
          loc, sizeIsNegative, indexOne, requestedExtent);
      resultExtents.push_back(extent);
      sourceExprs.push_back(mlir::getAffineConstantExpr(0, context));
      continue;
    }

    // Dimension that is carried over unchanged from the input.
    Value sourceExtent = getDimOp(rewriter, loc, input, srcIdx);
    Value sizeIsNegative = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, requestedSize, i64Zero);
    Value sourceExtentInt = castIndexToInt64(rewriter, loc, sourceExtent);
    Value sizeMatches = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, sourceExtentInt, requestedSize);
    Value sizeIsAcceptable =
        rewriter.create<arith::OrIOp>(loc, sizeIsNegative, sizeMatches);
    rewriter.create<cf::AssertOp>(
        loc, sizeIsAcceptable,
        rewriter.getStringAttr(
            "only broadcasting singleton dimensions supported"));
    resultExtents.push_back(sourceExtent);
    sourceExprs.push_back(mlir::getAffineDimExpr(dstIdx, context));
  }

  Value outTensor =
      rewriter.create<linalg::InitTensorOp>(loc, resultExtents, elementType);

  // The input is read through the map built above; the output is written
  // through the identity map.
  AffineMap sourceMap = AffineMap::get(targetRank, 0, sourceExprs, context);
  AffineMap resultMap = rewriter.getMultiDimIdentityMap(targetRank);
  SmallVector<AffineMap> indexingMaps = {sourceMap, resultMap};
  SmallVector<StringRef> iteratorTypes(targetRank, "parallel");

  auto broadcastOp = rewriter.create<linalg::GenericOp>(
      loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes,
      [](OpBuilder &builder, Location bodyLoc, ValueRange blockArgs) {
        // Forward the input element to the output unchanged.
        builder.create<linalg::YieldOp>(bodyLoc, blockArgs[0]);
      });
  result = broadcastOp.getResult(0);
  return success();
}
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
@ -995,8 +911,8 @@ public:
rewriter, op.getLoc(), getTypeConverter(), inShape);
Value result;
if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
result))) {
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, self, inShapeConverted, result))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
@ -1060,8 +976,8 @@ public:
for (unsigned i = 0; i < selfSizes.size(); i++)
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
Value broadcastedSrc;
if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
broadcastedSrc))) {
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, src, selfSizes, broadcastedSrc))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}

View File

@ -164,12 +164,16 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>();
unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
// Get the rank of both matrices.
unsigned lhsRank = lhsType.getRank();
unsigned rhsRank = rhsType.getRank();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
auto resultType = newResultType.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
// The different cases of the torch.matmul op are described here:
// https://pytorch.org/docs/stable/generated/torch.matmul.html
@ -228,39 +232,134 @@ public:
}
// Fourth Case: Batch-Matrix Multiplication.
// TODO: Broadcasting of batch dimension is remaining.
if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
// TODO: Handle batch matrix multiplication when one of the matrices has
// rank 1 and the other has batch dimensions.
if (lhsRank > 1 && rhsRank > 1) {
unsigned maxRank = std::max(lhsRank, rhsRank);
unsigned minRank = std::min(lhsRank, rhsRank);
unsigned batchRank = maxRank - 2;
unsigned batchRank = lhsRank - 2;
SmallVector<Value, 4> resultShape;
// At least one of the matrices must have rank greater than 2.
if (batchRank <= 0) {
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
}
// The `broadcastedBatchShape` contains batch dimensions of the resultant
// matrix.
SmallVector<Value> broadcastedBatchShape(batchRank);
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
Value maxDim;
// Compute broadcasted batch dimensions if the batch dimensions of
// the matrices are broadcastable.
for (unsigned i = 1; i <= batchRank; i++) {
if (i <= minRank - 2) {
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
} else {
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
}
broadcastedBatchShape[batchRank - i] = maxDim;
}
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
// Compute broadcasted shape of both the matrices in integer format.
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
lhsBroadcastToShape.push_back(lhsDim0);
lhsBroadcastToShape.push_back(lhsDim1);
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
rhsBroadcastToShape.push_back(rhsDim0);
rhsBroadcastToShape.push_back(rhsDim1);
for (unsigned i = 0; i < maxRank; i++) {
lhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
rhsBroadcastToShape[i] =
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
}
// Broadcast the batch dimensions of both the matrices.
Value broadcastedLhs, broadcastedRhs;
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
// Check if the result of the matrix multiplication has more than one
// dynamic batch dimension.
ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
bool multipleDynamicBatchDims =
llvm::count(batchDimsInt, kUnknownSize) > 1;
// TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
// at most one dynamic batch dimension due to limited support of the
// `tensor.ExpandShape` op.
if (!multipleDynamicBatchDims) {
// Collapse the batch dimensions into one dimension. The resultant rank
// will always be 3.
SmallVector<ReassociationIndices> reassociation(3);
for (unsigned i = 0, j = 0; i < maxRank; i++) {
if (i >= batchRank)
j++;
reassociation[j].push_back(i);
}
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedLhs, reassociation);
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), broadcastedRhs, reassociation);
// Compute the result shape after collapsing the batch dimensions.
SmallVector<Value> collapsedResultShape;
collapsedResultShape.push_back(broadcastedBatchShape[0]);
for (unsigned i = 1; i < batchRank; i++) {
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
}
collapsedResultShape.push_back(lhsDim0);
collapsedResultShape.push_back(rhsDim1);
SmallVector<OpFoldResult> updatedCollapseResultShape =
getAsOpFoldResult(collapsedResultShape);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, updatedCollapseResultShape, elementType);
Value c0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
Value batchMatMul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
.getResult(0);
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
loc, resultType, batchMatMul, reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
expandResult);
return success();
}
SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<StringRef> iteratorTypes;
// Since broadcasting is a TODO, check whether the lhs and rhs batch
// dimension match.
for (unsigned i = 0; i < batchRank; i++) {
Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
resultShape.push_back(lhsBatch);
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName());
checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
}
Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
// Push the final matrix dimension.
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 1)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
@ -268,9 +367,10 @@ public:
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 2)});
Value initTensor0 =
SmallVector<Value> resultShape(broadcastedBatchShape);
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, elementType);
auto indexingMaps =
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
iteratorTypes.insert(iteratorTypes.end(),
@ -279,7 +379,8 @@ public:
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0,
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {

View File

@ -256,3 +256,85 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
iteratorTypes, bodyBuild)
.getResult(0);
}
// Builds the IR that broadcasts `input` to the shape listed in
// `broadcastToShape`; the broadcast tensor is returned through `result`.
//
// `broadcastToShape` carries one integer size value per dimension of the
// result. It may be longer than the rank of `input`, in which case the
// additional entries are new leading dimensions. Per result dimension:
//   - new leading dimension: the requested size is asserted (at runtime) to be
//     non-negative and is used as the extent;
//   - input dimension that is statically 1: the extent is the requested size,
//     falling back to 1 if the requested size is negative;
//   - otherwise: the extent of the input dimension is reused, and a runtime
//     assert checks that the requested size is negative or equal to it.
//
// Fails with a match-failure diagnostic if `broadcastToShape` is shorter than
// the rank of `input`; succeeds in every other case.
LogicalResult torch_to_linalg::broadcastToGivenShape(
    Operation *op, PatternRewriter &rewriter, Value input,
    SmallVector<Value> broadcastToShape, Value &result) {
  RankedTensorType tensorType = input.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> staticDims = tensorType.getShape();
  size_t outRank = broadcastToShape.size();
  size_t inRank = staticDims.size();
  if (outRank < inRank)
    return rewriter.notifyMatchFailure(
        op, "invalid shape: broadcastToShape size must not be smaller than the "
            "size of the input shape");

  Type elemType = tensorType.getElementType();
  Location loc = op->getLoc();
  MLIRContext *ctx = op->getContext();

  SmallVector<Value> outDims;
  // Result expressions of the map through which `input` is indexed.
  SmallVector<AffineExpr> inputExprs;
  Value zeroI64 =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));

  // Result dimensions [0, rankGap) are new; the rest line up with `input`.
  size_t rankGap = outRank - inRank;
  for (size_t outDim = 0; outDim < outRank; ++outDim) {
    Value wanted = broadcastToShape[outDim];
    if (outDim < rankGap) {
      // New leading dimension: reject negative sizes at runtime.
      Value nonNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, wanted, zeroI64);
      rewriter.create<cf::AssertOp>(
          loc, nonNegative,
          rewriter.getStringAttr(
              "negative values not allowed in new dimensions"));
      Value extent = castIntToIndex(rewriter, loc, wanted);
      outDims.push_back(extent);
    } else if (staticDims[outDim - rankGap] == 1) {
      // Static singleton: stretch it by always reading input index 0.
      Value oneIndex =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
      Value wantedIsNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, wanted, zeroI64);
      Value wantedAsIndex = castIntToIndex(rewriter, loc, wanted);
      Value extent = rewriter.create<arith::SelectOp>(
          loc, wantedIsNegative, oneIndex, wantedAsIndex);
      outDims.push_back(extent);
      inputExprs.push_back(mlir::getAffineConstantExpr(0, ctx));
    } else {
      // Pass-through dimension: the requested size must not disagree with the
      // extent of the input.
      Value inputExtent = getDimOp(rewriter, loc, input, outDim - rankGap);
      Value wantedIsNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, wanted, zeroI64);
      Value inputExtentI64 = castIndexToInt64(rewriter, loc, inputExtent);
      Value sameExtent = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, inputExtentI64, wanted);
      Value accepted =
          rewriter.create<arith::OrIOp>(loc, wantedIsNegative, sameExtent);
      rewriter.create<cf::AssertOp>(
          loc, accepted,
          rewriter.getStringAttr(
              "only broadcasting singleton dimensions supported"));
      outDims.push_back(inputExtent);
      inputExprs.push_back(mlir::getAffineDimExpr(outDim, ctx));
    }
  }

  Value outTensor =
      rewriter.create<linalg::InitTensorOp>(loc, outDims, elemType);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(outRank, 0, inputExprs, ctx));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outRank));
  SmallVector<StringRef> iteratorTypes(outRank, "parallel");

  // The body only copies each input element to its output position(s).
  auto copyBody = [](OpBuilder &nested, Location nestedLoc,
                     ValueRange operands) {
    nested.create<linalg::YieldOp>(nestedLoc, operands[0]);
  };
  result = rewriter
               .create<linalg::GenericOp>(loc, outTensor.getType(), input,
                                          outTensor, indexingMaps,
                                          iteratorTypes, copyBody)
               .getResult(0);
  return success();
}

View File

@ -54,6 +54,13 @@ Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
// Broadcasts input tensor based on the broadcastToShape.
//
// `broadcastToShape` lists one integer size value per dimension of the
// broadcast result and must have at least as many entries as `input` has
// dimensions; entries beyond the rank of `input` describe new leading
// dimensions. On success the broadcast tensor is written to `result`.
// Returns failure (reported through `rewriter.notifyMatchFailure` on `op`)
// when `broadcastToShape` has fewer entries than the rank of `input`.
LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
Value input,
SmallVector<Value> broadcastToShape,
Value &result);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir

View File

@ -1129,8 +1129,10 @@ public:
SmallVector<int32_t> transposedLhsDims;
// Step: generate the common dim/shape information
bool hasDynamicDims = false;
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
hasDynamicDims |= isDynamicDim;
if (isDynamicDim ||
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
commonValue *= lhsBroadcastedShape[dim];
@ -1138,11 +1140,13 @@ public:
}
}
// TODO: Handle the case when there are dynamic batch dimensions.
if (hasDynamicDims)
commonValue = ShapedType::kDynamicSize;
// Step: generate the LHS squeezed dim/shape information.
bool hasDynamicDims = false;
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
hasDynamicDims |= isDynamicDim;
if (!isDynamicDim &&
lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
lhsSqueezedValue *= lhsBroadcastedShape[dim];

View File

@ -130,3 +130,82 @@ def Matmul_4d(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
# Test module: `torch.matmul` on two rank-4 float32 operands whose shapes are
# fully specified in the argument annotations ([4, 5, 6, 7] and [4, 5, 7, 6]).
class Matmul4dStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
# First entry is None (presumably the slot for `self` -- TODO confirm);
# the following entries annotate `lhs` and `rhs` in order.
None,
([4, 5, 6, 7], torch.float32, True),
([4, 5, 7, 6], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
# Registers the test case; it calls `forward` with random tensors whose shapes
# match the annotations above.
@register_test_case(module_factory=lambda: Matmul4dStatic())
def Matmul4dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
# Test module: `torch.matmul` on statically shaped float32 operands of
# different rank (rank 4 and rank 5) whose leading (batch) dimensions differ,
# so the batch dimensions have to be broadcast against each other.
class MatmulStaticBroadcast(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
# First entry is None (presumably the slot for `self` -- TODO confirm);
# the following entries annotate `lhs` and `rhs` in order.
None,
([4, 1, 6, 7], torch.float32, True),
([8, 1, 5, 7, 6], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
# Registers the test case; it calls `forward` with random tensors whose shapes
# match the annotations above.
@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
# ==============================================================================
# Test module: `torch.matmul` on two rank-4 float32 operands annotated as
# [4, -1, -1, -1]. The -1 entries presumably mark dynamic dimensions (TODO
# confirm), which leaves a single dynamic batch dimension next to the static
# leading 4.
class MatmulSingleDynamicBatchDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
# First entry is None (presumably the slot for `self` -- TODO confirm);
# the following entries annotate `lhs` and `rhs` in order.
None,
([4, -1, -1, -1], torch.float32, True),
([4, -1, -1, -1], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
# Registers the test case; it calls `forward` with concrete random tensors of
# shapes (4, 5, 6, 7) and (4, 5, 7, 6).
@register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim())
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
# Test module: `torch.matmul` on a rank-4 operand annotated [4, -1, -1, -1] and
# a rank-3 operand annotated [-1, -1, -1] (-1 presumably marks a dynamic
# dimension -- TODO confirm). The operands differ in rank, so the batch
# dimensions of the lower-rank operand have to be broadcast.
class MatmulBroadcastBatchDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
# First entry is None (presumably the slot for `self` -- TODO confirm);
# the following entries annotate `lhs` and `rhs` in order.
None,
([4, -1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
# Registers the test case; it calls `forward` with concrete random tensors of
# shapes (4, 5, 6, 7) and (5, 7, 6).
@register_test_case(module_factory=lambda: MatmulBroadcastBatchDim())
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))