From 6c9ba4ce95ceeb0e62164abb3eb6f03054dd69ff Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Fri, 7 Jul 2023 22:31:51 +0530
Subject: [PATCH] [Torch-to-Linalg] Add dynamic dimension support for
 BroadcastTo op (#2174)

-- This commit adds support for dynamic dimensions in the BroadcastTo op.
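-- For illustration, the pattern this unblocks is taking broadcast sizes
   from another tensor at runtime, as the new e2e test below does. A minimal
   eager-mode sketch (shapes chosen to match that test):

       import torch

       x = torch.rand(1, 2, 1, 4)  # dims 1 and 3 are dynamic in the e2e test
       y = torch.rand(1, 1, 1, 1)
       res = torch.ops.aten.broadcast_to(y, [1, x.size(1), 1, x.size(3)])
       assert res.shape == (1, 2, 1, 4)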
Signed-off-by: Abhishek Varma
---
 e2e_testing/xfail_sets.py                     |  3 +
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 23 +++++-
 lib/Conversion/TorchToLinalg/Utils.cpp        | 75 ++++++++++++-------
 lib/Conversion/TorchToLinalg/Utils.h          |  8 +-
 .../torch_mlir_e2e_test/test_suite/basic.py   | 26 +++++++
 5 files changed, 102 insertions(+), 33 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index e5e5082d4..3f21ad6ff 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -162,6 +162,9 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
     'SqrtIntConstantModule_basic',
 
+    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
+    'BroadcastDynamicDimModule_basic',
+
     # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
     'AtenIntBoolOpConstFalseModule_basic',
     'AtenIntBoolOpConstTrueModule_basic',
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 4877568a6..5f5d87c06 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1179,12 +1179,29 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: the size list is not from list construct");
     }
+    // For dynamic input dimension we need to use the `broadcastToShape`
+    // which in this case is `inShapeConverted` because this shape will yield
+    // us the dimension size of the output.
+    SmallVector<bool> useBroadcastToShape;
+    for (auto x : inShape) {
+      int64_t dim;
+      if (!matchPattern(x, m_TorchConstantInt(&dim))) {
+        Operation *defOp = x.getDefiningOp();
+        if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
+          useBroadcastToShape.push_back(true);
+        else
+          useBroadcastToShape.push_back(false);
+      } else {
+        useBroadcastToShape.push_back(false);
+      }
+    }
+
     SmallVector<Value> inShapeConverted = getTypeConvertedValues(
         rewriter, op.getLoc(), getTypeConverter(), inShape);
-
     Value result;
-    if (failed(torch_to_linalg::broadcastToGivenShape(
-            op, rewriter, self, inShapeConverted, result))) {
+    if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
+                                                      inShapeConverted, result,
+                                                      useBroadcastToShape))) {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 27299458d..4a47790b0 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -323,7 +323,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
 // Broadcasts input tensor based on the broadcastToShape.
 LogicalResult torch_to_linalg::broadcastToGivenShape(
     Operation *op, PatternRewriter &rewriter, Value input,
-    SmallVector<Value> broadcastToShape, Value &result) {
+    SmallVector<Value> broadcastToShape, Value &result,
+    SmallVector<bool> useBroadcastToShape) {
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
   SmallVector<int64_t> inputShape =
       makeShapeTorchCompatible(inputType.getShape());
@@ -335,13 +336,16 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
 
   Type elementType = inputType.getElementType();
   Location loc = op->getLoc();
-  MLIRContext *context = op->getContext();
   SmallVector<Value> outShape;
 
   // Create affine map and shapes for tensor initialization.
   SmallVector<AffineExpr> outExpr;
   Value zero =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
+  Value zeroIndex =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+  Value oneIndex =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
   size_t diff = broadcastToShape.size() - inputShape.size();
   for (size_t i = 0; i < broadcastToShape.size(); i++) {
     Value shapeValue = broadcastToShape[i];
@@ -358,46 +362,65 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     }
     if (inputShape[j] == 1) {
       // Broadcast singleton dimension
-      Value one =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
       Value isNegative = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::slt, shapeValue, zero);
       Value select = rewriter.create<arith::SelectOp>(
-          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
+          loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
       outShape.push_back(select);
-      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
-      continue;
+    } else {
+      // Case of dynamic input dimension wherein the shape to broadcast will
+      // yield us the dimension size of the output.
+      Value dim = getDimOp(rewriter, loc, input, j);
+      if (!useBroadcastToShape.empty()) {
+        if (useBroadcastToShape[i])
+          dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
+      }
+      outShape.push_back(dim);
     }
-    // Non-broadcast case
-    Value dim = getDimOp(rewriter, loc, input, j);
-    Value isNegative = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, shapeValue, zero);
-    Value isEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
-        shapeValue);
-    Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
-    rewriter.create<cf::AssertOp>(
-        loc, isValid,
-        rewriter.getStringAttr(
-            "only broadcasting singleton dimensions supported"));
-    outShape.push_back(dim);
-    outExpr.push_back(mlir::getAffineDimExpr(i, context));
   }
 
   Value outTensor = rewriter.create<tensor::EmptyOp>(
       loc, getAsOpFoldResult(outShape), elementType);
 
   SmallVector<AffineMap> indexingMaps = {
-      AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
       rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
   SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
                                                  utils::IteratorType::parallel);
   result = rewriter
               .create<linalg::GenericOp>(
-                   loc, outTensor.getType(), input, outTensor, indexingMaps,
-                   iteratorTypes,
-                   [](OpBuilder &b, Location loc, ValueRange args) {
-                     b.create<linalg::YieldOp>(loc, args[0]);
+                   loc, outTensor.getType(), ValueRange(), outTensor,
+                   indexingMaps, iteratorTypes,
+                   [&](OpBuilder &b, Location loc, ValueRange args) {
+                     // `loopIndices` contains IV of the linalg loops which
+                     // would be used to extract values from the input tensor
+                     // later on.
+                     SmallVector<Value> loopIndices;
+                     for (size_t i = 0; i < broadcastToShape.size(); ++i) {
+                       if (i < diff)
+                         continue;
+                       loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
+                     }
+                     // `inputIndicesToExtract` contains i-th linalg loop IV if
+                     // the i-th input dimension is not 1, else it contains a
+                     // zero index.
+                     SmallVector<Value> inputIndicesToExtract;
+                     for (size_t i = 0, n = inputShape.size(); i < n; i++) {
+                       if (inputShape[i] == 1) {
+                         inputIndicesToExtract.push_back(zeroIndex);
+                       } else {
+                         Value inputDim = getDimOp(b, loc, input, i);
+                         Value isEqual = b.create<arith::CmpIOp>(
+                             loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
+                         Value select = rewriter.create<arith::SelectOp>(
+                             loc, isEqual, zeroIndex, loopIndices[i]);
+                         inputIndicesToExtract.push_back(select);
+                       }
+                     }
+                     // Extract and yield the value from input tensor at
+                     // `inputIndicesToExtract` indices.
+                     Value result = b.create<tensor::ExtractOp>(
+                         loc, input, inputIndicesToExtract);
+                     b.create<linalg::YieldOp>(loc, result);
                    })
               .getResult(0);
 
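Review note on the Utils.cpp hunks above: the lowering no longer encodes the
broadcast in an affine input map; the input is dropped from the linalg.generic
operands and each output element is fetched with tensor.extract, using index 0
for size-1 input dimensions and the loop induction variable otherwise. A
minimal Python model of that index computation (the helper name
broadcast_index is illustrative, not part of the patch):

    # Map an output index to the input index to extract from: trailing-aligned,
    # with size-1 input dims pinned to 0 and all other dims passed through.
    def broadcast_index(out_index, input_shape):
        diff = len(out_index) - len(input_shape)
        return tuple(0 if size == 1 else iv
                     for size, iv in zip(input_shape, out_index[diff:]))

    assert broadcast_index((0, 1, 0, 3), (1, 1, 1, 1)) == (0, 0, 0, 0)
    assert broadcast_index((0, 1, 0, 3), (2, 1, 4)) == (1, 0, 3)
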
diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h
index 5fd5538c2..e60ecdc77 100644
--- a/lib/Conversion/TorchToLinalg/Utils.h
+++ b/lib/Conversion/TorchToLinalg/Utils.h
@@ -73,10 +73,10 @@ Value createElementwiseLinalgGeneric(
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
 
 // Broadcasts input tensor based on the broadcastToShape.
-LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
-                                    Value input,
-                                    SmallVector<Value> broadcastToShape,
-                                    Value &result);
+LogicalResult
+broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
+                      SmallVector<Value> broadcastToShape, Value &result,
+                      SmallVector<bool> useBroadcastToShape = {});
 
 // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
 // <?x?xf32>
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 5e1c05f97..e0a603634 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1394,6 +1394,32 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class BroadcastDynamicDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, -1, 1, -1], torch.float32, True),
+        ([1, -1, 1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        dim_at_index_1 = torch.ops.aten.size(x, 1)
+        dim_at_index_3 = torch.ops.aten.size(x, 3)
+        res = torch.ops.aten.broadcast_to(y, [1, dim_at_index_1, 1, dim_at_index_3])
+        return res
+
+
+@register_test_case(module_factory=lambda: BroadcastDynamicDimModule())
+def BroadcastDynamicDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 1, 4), tu.rand(1, 1, 1, 1))
+
+
+# ==============================================================================
+
+
 class RollModule(torch.nn.Module):
 
     def __init__(self):
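As a quick sanity check of the semantics BroadcastDynamicDimModule locks down,
the result should agree with eager PyTorch's expand. A small standalone sketch
(assuming only stock PyTorch, independent of the torch-mlir harness):

    import torch

    x = torch.rand(1, 2, 1, 4)
    y = torch.rand(1, 1, 1, 1)
    res = torch.ops.aten.broadcast_to(
        y, [1, torch.ops.aten.size(x, 1), 1, torch.ops.aten.size(x, 3)])
    assert res.shape == (1, 2, 1, 4)
    assert torch.equal(res, y.expand(1, 2, 1, 4))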