diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index fb8131ac7..ad4fd283a 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -61,6 +61,7 @@ TOSA_PASS_SET = {
     "ElementwisePowModule_basic",
     "BmmModule_basic",
     "MmDagModule_basic",
+    "Matmul4dStatic_basic",
    "Matmul_dot",
    "Matmul_3d",
    "RsubFloatModule_basic",
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 7884113b5..d0e3ae7d0 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -890,90 +890,6 @@ public:
 };
 } // namespace
 
-// Broadcasts input tensor based on the broadcastToShape.
-static LogicalResult broadcastToGivenShape(Operation *op,
-                                           ConversionPatternRewriter &rewriter,
-                                           Value input,
-                                           SmallVector<Value> broadcastToShape,
-                                           Value &result) {
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  if (broadcastToShape.size() < inputShape.size()) {
-    return rewriter.notifyMatchFailure(
-        op, "invalid shape: broadcastToShape size must not be smaller than the "
-            "size of the input shape");
-  }
-
-  Type elementType = inputType.getElementType();
-  Location loc = op->getLoc();
-  MLIRContext *context = op->getContext();
-  SmallVector<Value> outShape;
-
-  // Create affine map and shapes for tensor initialization.
-  SmallVector<AffineExpr> outExpr;
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
-  size_t diff = broadcastToShape.size() - inputShape.size();
-  for (size_t i = 0; i < broadcastToShape.size(); i++) {
-    Value shapeValue = broadcastToShape[i];
-    size_t j = i - diff;
-    if (i < diff) {
-      Value isValid = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, shapeValue, zero);
-      rewriter.create<cf::AssertOp>(
-          loc, isValid,
-          rewriter.getStringAttr(
-              "negative values not allowed in new dimensions"));
-      outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
-      continue;
-    }
-    if (inputShape[j] == 1) {
-      // Broadcast singleton dimension
-      Value one =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-      Value isNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, shapeValue, zero);
-      Value select = rewriter.create<arith::SelectOp>(
-          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
-      outShape.push_back(select);
-      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
-      continue;
-    }
-    // Non-broadcast case
-    Value dim = getDimOp(rewriter, loc, input, j);
-    Value isNegative = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, shapeValue, zero);
-    Value isEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
-        shapeValue);
-    Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
-    rewriter.create<cf::AssertOp>(
-        loc, isValid,
-        rewriter.getStringAttr(
-            "only broadcasting singleton dimensions supported"));
-    outShape.push_back(dim);
-    outExpr.push_back(mlir::getAffineDimExpr(i, context));
-  }
-
-  Value outTensor =
-      rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
-
-  SmallVector<AffineMap> indexingMaps = {
-      AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
-      rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
-  SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
-  result = rewriter
-               .create<linalg::GenericOp>(
-                   loc, outTensor.getType(), input, outTensor, indexingMaps,
-                   iteratorTypes,
-                   [](OpBuilder &b, Location loc, ValueRange args) {
-                     b.create<linalg::YieldOp>(loc, args[0]);
-                   })
-               .getResult(0);
-
-  return success();
-}
-
 namespace {
 class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
 public:
@@ -995,8 +911,8 @@ public:
         rewriter, op.getLoc(), getTypeConverter(), inShape);
 
     Value result;
-    if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
-                                     result))) {
+    if (failed(torch_to_linalg::broadcastToGivenShape(
+            op, rewriter, self, inShapeConverted, result))) {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
@@ -1060,8 +976,8 @@ public:
     for (unsigned i = 0; i < selfSizes.size(); i++)
       selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
     Value broadcastedSrc;
-    if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
-                                     broadcastedSrc))) {
+    if (failed(torch_to_linalg::broadcastToGivenShape(
+            op, rewriter, src, selfSizes, broadcastedSrc))) {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 8fa598f8d..550e41d27 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -164,12 +164,16 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
+    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    auto rhsType = rhs.getType().cast<RankedTensorType>();
-    unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
-    unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
+    // Get the rank of both matrices.
+    unsigned lhsRank = lhsType.getRank();
+    unsigned rhsRank = rhsType.getRank();
 
     Type newResultType = getTypeConverter()->convertType(op.getType());
-    Type elementType = newResultType.cast<RankedTensorType>().getElementType();
+    auto resultType = newResultType.cast<RankedTensorType>();
+    Type elementType = resultType.getElementType();
 
     // The different cases of torch_matmul op is mentioned here:
     // https://pytorch.org/docs/stable/generated/torch.matmul.html
@@ -228,39 +232,134 @@ public:
     }
 
     // Fourth Case: Batch-Matrix Multiplication.
-    // TODO: Broadcasting of batch dimension is remaining.
-    if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
+    // TODO: Handle batch matrix multiplication when one of the matrices has
+    // rank 1 and the other has batch dimensions.
+    if (lhsRank > 1 && rhsRank > 1) {
+      unsigned maxRank = std::max(lhsRank, rhsRank);
+      unsigned minRank = std::min(lhsRank, rhsRank);
+      unsigned batchRank = maxRank - 2;
 
-      unsigned batchRank = lhsRank - 2;
-      SmallVector<Value> resultShape;
+      // At least one of the matrices must have rank greater than 2.
+      if (batchRank <= 0) {
+        return rewriter.notifyMatchFailure(op, "expected batch dimensions");
+      }
+
+      // The `broadcastedBatchShape` contains the batch dimensions of the
+      // result matrix.
+      SmallVector<Value> broadcastedBatchShape(batchRank);
+      Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
+      Value maxDim;
+      // Compute broadcasted batch dimensions if the batch dimensions of
+      // the matrices are broadcastable.
+      for (unsigned i = 1; i <= batchRank; i++) {
+        if (i <= minRank - 2) {
+          Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
+          Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
+          maxDim = rewriter.createOrFold<arith::MaxSIOp>(loc, lhsDim, rhsDim);
+        } else {
+          maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
+        }
+        broadcastedBatchShape[batchRank - i] = maxDim;
+      }
+
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      // Compute broadcasted shape of both the matrices in integer format.
+      SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
+      lhsBroadcastToShape.push_back(lhsDim0);
+      lhsBroadcastToShape.push_back(lhsDim1);
+      SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
+      rhsBroadcastToShape.push_back(rhsDim0);
+      rhsBroadcastToShape.push_back(rhsDim1);
+      for (unsigned i = 0; i < maxRank; i++) {
+        lhsBroadcastToShape[i] =
+            castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
+        rhsBroadcastToShape[i] =
+            castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
+      }
+
+      // Broadcast the batch dimensions of both the matrices.
+      Value broadcastedLhs, broadcastedRhs;
+      if (failed(torch_to_linalg::broadcastToGivenShape(
+              op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
+        return rewriter.notifyMatchFailure(
+            op, "unable to perform broadcast operation");
+      }
+      if (failed(torch_to_linalg::broadcastToGivenShape(
+              op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
+        return rewriter.notifyMatchFailure(
+            op, "unable to perform broadcast operation");
+      }
+
+      // Check if the result of the matrix multiplication has more than one
+      // dynamic batch dimension.
+      ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
+      bool multipleDynamicBatchDims =
+          llvm::count(batchDimsInt, kUnknownSize) > 1;
+
+      // TODO: Lowering to `linalg.batch_matmul` is only possible when there is
+      // at most one dynamic batch dimension due to limited support of the
+      // `tensor.expand_shape` op.
+      if (!multipleDynamicBatchDims) {
+        // Collapse the batch dimensions into one dimension. The resultant rank
+        // will always be 3.
+        SmallVector<ReassociationIndices> reassociation(3);
+        for (unsigned i = 0, j = 0; i < maxRank; i++) {
+          if (i >= batchRank)
+            j++;
+          reassociation[j].push_back(i);
+        }
+        Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
+            op->getLoc(), broadcastedLhs, reassociation);
+        Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
+            op->getLoc(), broadcastedRhs, reassociation);
+
+        // Compute the result shape after collapsing the batch dimensions.
+        SmallVector<Value> collapsedResultShape;
+        collapsedResultShape.push_back(broadcastedBatchShape[0]);
+        for (unsigned i = 1; i < batchRank; i++) {
+          collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
+              loc, collapsedResultShape[0], broadcastedBatchShape[i]);
+        }
+        collapsedResultShape.push_back(lhsDim0);
+        collapsedResultShape.push_back(rhsDim1);
+        SmallVector<OpFoldResult> updatedCollapseResultShape =
+            getAsOpFoldResult(collapsedResultShape);
+
+        Value initTensor = rewriter.create<linalg::InitTensorOp>(
+            loc, updatedCollapseResultShape, elementType);
+        Value c0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(elementType));
+        Value zeroTensor =
+            rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
+
+        Value batchMatMul =
+            rewriter
+                .create<linalg::BatchMatmulOp>(
+                    loc, zeroTensor.getType(),
+                    ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
+                .getResult(0);
+        Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
+            loc, resultType, batchMatMul, reassociation);
+        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
+                                                    expandResult);
+        return success();
+      }
 
       SmallVector<AffineExpr> lhsExpr;
       SmallVector<AffineExpr> rhsExpr;
       SmallVector<AffineExpr> outExpr;
       SmallVector<StringRef> iteratorTypes;
-
-      // Since broadcasting is a TODO, check whether the lhs and rhs batch
-      // dimension match.
       for (unsigned i = 0; i < batchRank; i++) {
-        Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
-        Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
-        resultShape.push_back(lhsBatch);
         lhsExpr.push_back(rewriter.getAffineDimExpr(i));
         rhsExpr.push_back(rewriter.getAffineDimExpr(i));
         outExpr.push_back(rewriter.getAffineDimExpr(i));
         iteratorTypes.push_back(getParallelIteratorTypeName());
-        checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
       }
-
-      Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
-      Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
-      Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
-      Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
-      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
-
-      // Push the final matrix dimension.
-      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
-
       lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                      rewriter.getAffineDimExpr(batchRank + 1)});
       rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
@@ -268,9 +367,10 @@ public:
       outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                      rewriter.getAffineDimExpr(batchRank + 2)});
-      Value initTensor0 =
+      SmallVector<Value> resultShape(broadcastedBatchShape);
+      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
+      Value zeroTensor =
           createZeroInitTensor(rewriter, loc, resultShape, elementType);
-
       auto indexingMaps =
           AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
       iteratorTypes.insert(iteratorTypes.end(),
@@ -279,7 +379,8 @@ public:
       Value finalRes =
           rewriter
              .create<linalg::GenericOp>(
-                  loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0,
+                  loc, zeroTensor.getType(),
+                  ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
                  /*indexingMaps=*/indexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 4d8492b45..0d04dc552 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -256,3 +256,85 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
           iteratorTypes, bodyBuild)
       .getResult(0);
 }
+
+// Broadcasts input tensor based on the broadcastToShape.
+LogicalResult torch_to_linalg::broadcastToGivenShape(
+    Operation *op, PatternRewriter &rewriter, Value input,
+    SmallVector<Value> broadcastToShape, Value &result) {
+  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  if (broadcastToShape.size() < inputShape.size()) {
+    return rewriter.notifyMatchFailure(
+        op, "invalid shape: broadcastToShape size must not be smaller than the "
+            "size of the input shape");
+  }
+
+  Type elementType = inputType.getElementType();
+  Location loc = op->getLoc();
+  MLIRContext *context = op->getContext();
+  SmallVector<Value> outShape;
+
+  // Create affine map and shapes for tensor initialization.
+  SmallVector<AffineExpr> outExpr;
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
+  size_t diff = broadcastToShape.size() - inputShape.size();
+  for (size_t i = 0; i < broadcastToShape.size(); i++) {
+    Value shapeValue = broadcastToShape[i];
+    size_t j = i - diff;
+    if (i < diff) {
+      Value isValid = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, shapeValue, zero);
+      rewriter.create<cf::AssertOp>(
+          loc, isValid,
+          rewriter.getStringAttr(
+              "negative values not allowed in new dimensions"));
+      outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
+      continue;
+    }
+    if (inputShape[j] == 1) {
+      // Broadcast singleton dimension
+      Value one =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+      Value isNegative = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, shapeValue, zero);
+      Value select = rewriter.create<arith::SelectOp>(
+          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
+      outShape.push_back(select);
+      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
+      continue;
+    }
+    // Non-broadcast case
+    Value dim = getDimOp(rewriter, loc, input, j);
+    Value isNegative = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, shapeValue, zero);
+    Value isEqual = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
+        shapeValue);
+    Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
+    rewriter.create<cf::AssertOp>(
+        loc, isValid,
+        rewriter.getStringAttr(
+            "only broadcasting singleton dimensions supported"));
+    outShape.push_back(dim);
+    outExpr.push_back(mlir::getAffineDimExpr(i, context));
+  }
+
+  Value outTensor =
+      rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
+
+  SmallVector<AffineMap> indexingMaps = {
+      AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
+      rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
+  SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
+  result = rewriter
+               .create<linalg::GenericOp>(
+                   loc, outTensor.getType(), input, outTensor, indexingMaps,
+                   iteratorTypes,
+                   [](OpBuilder &b, Location loc, ValueRange args) {
+                     b.create<linalg::YieldOp>(loc, args[0]);
+                   })
+               .getResult(0);
+
+  return success();
+}
diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h
index eb16387e0..6279b8c9e 100644
--- a/lib/Conversion/TorchToLinalg/Utils.h
+++ b/lib/Conversion/TorchToLinalg/Utils.h
@@ -54,6 +54,13 @@ Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
+
+// Broadcasts input tensor based on the broadcastToShape.
+LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
+                                    Value input,
+                                    SmallVector<Value> broadcastToShape,
+                                    Value &result);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 75d8f6da8..1e09d9527 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1129,8 +1129,10 @@ public:
      SmallVector<int32_t> transposedLhsDims;

      // Step: generate the common dim/shape information
+      bool hasDynamicDims = false;
      for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
        bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
+        hasDynamicDims |= isDynamicDim;
        if (isDynamicDim ||
            lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
          commonValue *= lhsBroadcastedShape[dim];
@@ -1138,11 +1140,13 @@ public:
        }
      }

+      // TODO: Handle the case when there are dynamic batch dimensions.
+      if (hasDynamicDims)
+        commonValue = ShapedType::kDynamicSize;
+
      // Step: generate the LHS squeezed dim/shape information.
-      bool hasDynamicDims = false;
      for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
        bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
-        hasDynamicDims |= isDynamicDim;
        if (!isDynamicDim &&
            lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
          lhsSqueezedValue *= lhsBroadcastedShape[dim];
diff --git a/python/torch_mlir_e2e_test/test_suite/matmul.py b/python/torch_mlir_e2e_test/test_suite/matmul.py
index 09d8f50a0..5b9502cae 100644
--- a/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -130,3 +130,82 @@ def Matmul_4d(module, tu: TestUtils):
     module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
 
 # ==============================================================================
+
+class Matmul4dStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6, 7], torch.float32, True),
+        ([4, 5, 7, 6], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul4dStatic())
+def Matmul4dStatic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
+
+# ==============================================================================
+
+class MatmulStaticBroadcast(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 1, 6, 7], torch.float32, True),
+        ([8, 1, 5, 7, 6], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
+def MatmulStaticBroadcast_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
+
+# ==============================================================================
+
+class MatmulSingleDynamicBatchDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, -1, -1, -1], torch.float32, True),
+        ([4, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim())
+def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
+
+# ==============================================================================
+
+class MatmulBroadcastBatchDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulBroadcastBatchDim())
+def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
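
Note (reviewer aid, not part of the patch): the result shapes expected by the new broadcast-aware tests can be sanity-checked against stock PyTorch. The snippet below is only an illustration of torch.matmul's batch-broadcasting rules (https://pytorch.org/docs/stable/generated/torch.matmul.html), reusing the shapes from MatmulStaticBroadcast and MatmulBroadcastBatchDim above; it does not come from this patch.

    import torch

    # Batch dims (4, 1) and (8, 1, 5) broadcast to (8, 4, 5); the matrix dims
    # (6, 7) x (7, 6) contribute (6, 6), matching MatmulStaticBroadcast.
    assert torch.matmul(torch.rand(4, 1, 6, 7),
                        torch.rand(8, 1, 5, 7, 6)).shape == (8, 4, 5, 6, 6)

    # A rank-3 rhs carries a single batch dim (5,), which broadcasts against
    # the lhs batch dims (4, 5), matching MatmulBroadcastBatchDim.
    assert torch.matmul(torch.rand(4, 5, 6, 7),
                        torch.rand(5, 7, 6)).shape == (4, 5, 6, 6)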