[Torch-to-Linalg] Add dynamic dimension support for BroadcastTo op (#2174)

-- This commit adds support for dynamic dimensions in the BroadcastTo op.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek Varma 2023-07-07 22:31:51 +05:30 committed by GitHub
parent 7f4084b570
commit 6c9ba4ce95
5 changed files with 102 additions and 33 deletions
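
For context, a minimal eager-mode PyTorch sketch of the pattern this commit enables (the function name is illustrative; it mirrors the BroadcastDynamicDimModule e2e test added further down): the target shape passed to broadcast_to is built partly from runtime aten.size queries rather than from constants.

import torch

def broadcast_with_dynamic_dims(x, y):
    # Target sizes are queried at runtime instead of being given as constants.
    dim1 = torch.ops.aten.size(x, 1)
    dim3 = torch.ops.aten.size(x, 3)
    # broadcast_to whose target shape is only partially known at compile time.
    return torch.ops.aten.broadcast_to(y, [1, dim1, 1, dim3])

out = broadcast_with_dynamic_dims(torch.rand(1, 2, 1, 4), torch.rand(1, 1, 1, 1))
print(out.shape)  # torch.Size([1, 2, 1, 4])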


@@ -162,6 +162,9 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
'SqrtIntConstantModule_basic',
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
'BroadcastDynamicDimModule_basic',
# START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
'AtenIntBoolOpConstFalseModule_basic',
'AtenIntBoolOpConstTrueModule_basic',


@@ -1179,12 +1179,29 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: the size list is not from list construct");
}
// For dynamic input dimension we need to use the `broadcastToShape`
// which in this case is `inShapeConverted` because this shape will yield
// us the dimension size of the output.
SmallVector<bool> useBroadcastToShape;
for (auto x : inShape) {
int64_t dim;
if (!matchPattern(x, m_TorchConstantInt(&dim))) {
Operation* defOp = x.getDefiningOp();
if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
useBroadcastToShape.push_back(true);
else
useBroadcastToShape.push_back(false);
} else {
useBroadcastToShape.push_back(false);
}
}
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
rewriter, op.getLoc(), getTypeConverter(), inShape);
Value result;
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, self, inShapeConverted, result))) {
if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
inShapeConverted, result,
useBroadcastToShape))) {
return rewriter.notifyMatchFailure(
op, "unable to perform broadcast operation");
}
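
A rough plain-Python sketch (illustrative only; the dictionary entries are hypothetical stand-ins for torch-mlir Values) of the per-element decision made above: an element of the target size list is flagged in useBroadcastToShape only when it is not a constant integer and is produced by an aten.size / aten.size.int op, in which case the output dimension is later taken from the converted size list instead of from the input tensor.

def compute_use_broadcast_to_shape(size_list_entries):
    flags = []
    for entry in size_list_entries:
        if entry["constant_int"] is not None:
            # Static size: the dimension can be read from the input tensor.
            flags.append(False)
        else:
            # Dynamic size: only use the size-list entry when it comes from
            # aten.size(.int), i.e. it directly names the output dimension.
            flags.append(entry["defining_op"] in ("aten.size", "aten.size.int"))
    return flags

# Size list [1, x.size(1), 1, x.size(3)] from the new e2e test:
entries = [
    {"constant_int": 1,    "defining_op": "torch.constant.int"},
    {"constant_int": None, "defining_op": "aten.size.int"},
    {"constant_int": 1,    "defining_op": "torch.constant.int"},
    {"constant_int": None, "defining_op": "aten.size.int"},
]
print(compute_use_broadcast_to_shape(entries))  # [False, True, False, True]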


@@ -323,7 +323,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
// Broadcasts input tensor based on the broadcastToShape.
LogicalResult torch_to_linalg::broadcastToGivenShape(
Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, Value &result) {
SmallVector<Value> broadcastToShape, Value &result,
SmallVector<bool> useBroadcastToShape) {
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
SmallVector<int64_t> inputShape =
makeShapeTorchCompatible(inputType.getShape());
@@ -335,13 +336,16 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
Type elementType = inputType.getElementType();
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
SmallVector<Value> outShape;
// Create affine map and shapes for tensor initialization.
SmallVector<AffineExpr> outExpr;
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
Value zeroIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value oneIndex =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
size_t diff = broadcastToShape.size() - inputShape.size();
for (size_t i = 0; i < broadcastToShape.size(); i++) {
Value shapeValue = broadcastToShape[i];
@@ -358,46 +362,65 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
}
if (inputShape[j] == 1) {
// Broadcast singleton dimension
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value isNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value select = rewriter.create<arith::SelectOp>(
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
outShape.push_back(select);
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
continue;
} else {
// Case of dynamic input dimension wherein the shape to broadcast will
// yield us the dimension size of the output.
Value dim = getDimOp(rewriter, loc, input, j);
if (!useBroadcastToShape.empty()) {
if (useBroadcastToShape[i])
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
}
outShape.push_back(dim);
}
// Non-broadcast case
Value dim = getDimOp(rewriter, loc, input, j);
Value isNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value isEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
shapeValue);
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
rewriter.create<cf::AssertOp>(
loc, isValid,
rewriter.getStringAttr(
"only broadcasting singleton dimensions supported"));
outShape.push_back(dim);
outExpr.push_back(mlir::getAffineDimExpr(i, context));
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outShape), elementType);
SmallVector<AffineMap> indexingMaps = {
AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
utils::IteratorType::parallel);
result = rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), input, outTensor, indexingMaps,
iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
loc, outTensor.getType(), ValueRange(), outTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// `loopIndices` contains IV of the linalg loops which
// would be used to extract values from the input tensor
// later on.
SmallVector<Value> loopIndices;
for (size_t i = 0; i < broadcastToShape.size(); ++i) {
if (i < diff)
continue;
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
}
// `inputIndicesToExtract` contains i-th linalg loop IV if
// the i-th input dimension is not 1, else it contains a
// zero index.
SmallVector<Value> inputIndicesToExtract;
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
if (inputShape[i] == 1) {
inputIndicesToExtract.push_back(zeroIndex);
} else {
Value inputDim = getDimOp(b, loc, input, i);
Value isEqual = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
Value select = rewriter.create<arith::SelectOp>(
loc, isEqual, zeroIndex, loopIndices[i]);
inputIndicesToExtract.push_back(select);
}
}
// Extract and yield the value from input tensor at
// `inputIndicesToExtract` indices.
Value result = b.create<tensor::ExtractOp>(
loc, input, inputIndicesToExtract);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
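
To make the rewritten linalg.generic body easier to follow, here is a minimal NumPy sketch (illustrative only, not the actual lowering) of what it computes: every output element is read from the input at the same coordinates, except that the leading added dimensions are dropped and size-1 input dimensions are clamped to index 0.

import itertools
import numpy as np

def broadcast_to_given_shape(inp, broadcast_to_shape):
    diff = len(broadcast_to_shape) - inp.ndim
    out = np.empty(broadcast_to_shape, dtype=inp.dtype)
    for out_idx in itertools.product(*(range(s) for s in broadcast_to_shape)):
        # Drop the leading (newly added) dimensions, then use index 0 wherever
        # the input dimension is a singleton.
        in_idx = tuple(0 if inp.shape[j] == 1 else out_idx[j + diff]
                       for j in range(inp.ndim))
        out[out_idx] = inp[in_idx]
    return out

y = np.random.rand(1, 1, 1, 1).astype(np.float32)
print(broadcast_to_given_shape(y, (1, 2, 1, 4)).shape)  # (1, 2, 1, 4)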


@@ -73,10 +73,10 @@ Value createElementwiseLinalgGeneric(
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
// Broadcasts input tensor based on the broadcastToShape.
LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
Value input,
SmallVector<Value> broadcastToShape,
Value &result);
LogicalResult
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, Value &result,
SmallVector<bool> useBroadcastToShape = {});
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>


@@ -1394,6 +1394,32 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils):
# ==============================================================================
class BroadcastDynamicDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, -1, 1, -1], torch.float32, True),
        ([1, -1, 1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        dim_at_index_1 = torch.ops.aten.size(x, 1)
        dim_at_index_3 = torch.ops.aten.size(x, 3)
        res = torch.ops.aten.broadcast_to(y, [1, dim_at_index_1, 1, dim_at_index_3])
        return res


@register_test_case(module_factory=lambda: BroadcastDynamicDimModule())
def BroadcastDynamicDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 2, 1, 4), tu.rand(1, 1, 1, 1))
# ==============================================================================
class RollModule(torch.nn.Module):

    def __init__(self):