mirror of https://github.com/llvm/torch-mlir
[Torch-to-Linalg] Add dynamic dimension support for BroadcastTo op (#2174)
This commit adds support for dynamic dimensions in the BroadcastTo op. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
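For context, a minimal PyTorch sketch (illustrative only, not part of the patch) of the pattern this lowering now supports: broadcast_to where some target sizes are taken at runtime from aten.size on a dynamically shaped tensor.

import torch

# Illustrative: target sizes for broadcast_to come from runtime size queries,
# i.e. dynamic dimensions from the Torch-to-Linalg lowering's point of view.
x = torch.rand(1, 2, 1, 4)
y = torch.rand(1, 1, 1, 1)
res = torch.ops.aten.broadcast_to(y, [1, x.size(1), 1, x.size(3)])
print(res.shape)  # torch.Size([1, 2, 1, 4])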
parent 7f4084b570
commit 6c9ba4ce95
@@ -162,6 +162,9 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
     'SqrtIntConstantModule_basic',

+    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
+    'BroadcastDynamicDimModule_basic',
+
     # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
     'AtenIntBoolOpConstFalseModule_basic',
     'AtenIntBoolOpConstTrueModule_basic',
@@ -1179,12 +1179,29 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: the size list is not from list construct");
     }
+    // For dynamic input dimension we need to use the `broadcastToShape`
+    // which in this case is `inShapeConverted` because this shape will yield
+    // us the dimension size of the output.
+    SmallVector<bool> useBroadcastToShape;
+    for (auto x : inShape) {
+      int64_t dim;
+      if (!matchPattern(x, m_TorchConstantInt(&dim))) {
+        Operation *defOp = x.getDefiningOp();
+        if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
+          useBroadcastToShape.push_back(true);
+        else
+          useBroadcastToShape.push_back(false);
+      } else {
+        useBroadcastToShape.push_back(false);
+      }
+    }
     SmallVector<Value> inShapeConverted = getTypeConvertedValues(
         rewriter, op.getLoc(), getTypeConverter(), inShape);

     Value result;
-    if (failed(torch_to_linalg::broadcastToGivenShape(
-            op, rewriter, self, inShapeConverted, result))) {
+    if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
+                                                      inShapeConverted, result,
+                                                      useBroadcastToShape))) {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
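To illustrate the intent of useBroadcastToShape above (a rough Python sketch with placeholder values, not the C++ logic itself): an entry is set to true exactly when the corresponding target size is produced by an aten.size / aten.size.int op rather than being a Torch integer constant, so the output extent must be read from the converted shape values.

# Illustrative only: mark which target-shape entries are dynamic, i.e. come
# from a runtime size query instead of a compile-time integer constant.
target_sizes = [1, "aten.size.int(x, 1)", 1, "aten.size.int(x, 3)"]  # placeholder mix
use_broadcast_to_shape = [not isinstance(s, int) for s in target_sizes]
print(use_broadcast_to_shape)  # [False, True, False, True]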
@@ -323,7 +323,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
 // Broadcasts input tensor based on the broadcastToShape.
 LogicalResult torch_to_linalg::broadcastToGivenShape(
     Operation *op, PatternRewriter &rewriter, Value input,
-    SmallVector<Value> broadcastToShape, Value &result) {
+    SmallVector<Value> broadcastToShape, Value &result,
+    SmallVector<bool> useBroadcastToShape) {
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
   SmallVector<int64_t> inputShape =
       makeShapeTorchCompatible(inputType.getShape());
@@ -335,13 +336,16 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(

   Type elementType = inputType.getElementType();
   Location loc = op->getLoc();
   MLIRContext *context = op->getContext();
   SmallVector<Value> outShape;

   // Create affine map and shapes for tensor initialization.
   SmallVector<AffineExpr> outExpr;
   Value zero =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
+  Value zeroIndex =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+  Value oneIndex =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
   size_t diff = broadcastToShape.size() - inputShape.size();
   for (size_t i = 0; i < broadcastToShape.size(); i++) {
     Value shapeValue = broadcastToShape[i];
@@ -358,46 +362,65 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     }
     if (inputShape[j] == 1) {
       // Broadcast singleton dimension
-      Value one =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
       Value isNegative = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::slt, shapeValue, zero);
       Value select = rewriter.create<arith::SelectOp>(
-          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
+          loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
       outShape.push_back(select);
       outExpr.push_back(mlir::getAffineConstantExpr(0, context));
       continue;
+    } else {
+      // Case of dynamic input dimension wherein the shape to broadcast will
+      // yield us the dimension size of the output.
+      Value dim = getDimOp(rewriter, loc, input, j);
+      if (!useBroadcastToShape.empty()) {
+        if (useBroadcastToShape[i])
+          dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
+      }
+      outShape.push_back(dim);
     }
-    // Non-broadcast case
-    Value dim = getDimOp(rewriter, loc, input, j);
-    Value isNegative = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, shapeValue, zero);
-    Value isEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
-        shapeValue);
-    Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
-    rewriter.create<cf::AssertOp>(
-        loc, isValid,
-        rewriter.getStringAttr(
-            "only broadcasting singleton dimensions supported"));
-    outShape.push_back(dim);
     outExpr.push_back(mlir::getAffineDimExpr(i, context));
   }

   Value outTensor = rewriter.create<tensor::EmptyOp>(
       loc, getAsOpFoldResult(outShape), elementType);

   SmallVector<AffineMap> indexingMaps = {
       AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
       rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
   SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
                                                  utils::IteratorType::parallel);
   result = rewriter
                .create<linalg::GenericOp>(
-                    loc, outTensor.getType(), input, outTensor, indexingMaps,
-                    iteratorTypes,
-                    [](OpBuilder &b, Location loc, ValueRange args) {
-                      b.create<linalg::YieldOp>(loc, args[0]);
+                    loc, outTensor.getType(), ValueRange(), outTensor,
+                    indexingMaps, iteratorTypes,
+                    [&](OpBuilder &b, Location loc, ValueRange args) {
+                      // `loopIndices` contains IV of the linalg loops which
+                      // would be used to extract values from the input tensor
+                      // later on.
+                      SmallVector<Value> loopIndices;
+                      for (size_t i = 0; i < broadcastToShape.size(); ++i) {
+                        if (i < diff)
+                          continue;
+                        loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
+                      }
+                      // `inputIndicesToExtract` contains i-th linalg loop IV if
+                      // the i-th input dimension is not 1, else it contains a
+                      // zero index.
+                      SmallVector<Value> inputIndicesToExtract;
+                      for (size_t i = 0, n = inputShape.size(); i < n; i++) {
+                        if (inputShape[i] == 1) {
+                          inputIndicesToExtract.push_back(zeroIndex);
+                        } else {
+                          Value inputDim = getDimOp(b, loc, input, i);
+                          Value isEqual = b.create<arith::CmpIOp>(
+                              loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
+                          Value select = rewriter.create<arith::SelectOp>(
+                              loc, isEqual, zeroIndex, loopIndices[i]);
+                          inputIndicesToExtract.push_back(select);
+                        }
+                      }
+                      // Extract and yield the value from input tensor at
+                      // `inputIndicesToExtract` indices.
+                      Value result = b.create<tensor::ExtractOp>(
+                          loc, input, inputIndicesToExtract);
+                      b.create<linalg::YieldOp>(loc, result);
                     })
                .getResult(0);
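As a rough model of what the rewritten linalg.generic body computes (an illustrative Python sketch, not the MLIR lowering itself): every output element is gathered from the input via tensor.extract, using index 0 along input dimensions of size 1 and the corresponding loop index elsewhere. The helper name below is made up for the example.

import numpy as np

# Illustrative model only: broadcast by per-element gather, mirroring the
# tensor.extract-based linalg.generic body above.
def broadcast_by_extract(inp, out_shape):
    diff = len(out_shape) - inp.ndim  # number of leading dims being added
    out = np.empty(out_shape, dtype=inp.dtype)
    for out_idx in np.ndindex(*out_shape):
        # Use index 0 where the input dimension is 1, else the loop index.
        in_idx = tuple(0 if inp.shape[d] == 1 else out_idx[d + diff]
                       for d in range(inp.ndim))
        out[out_idx] = inp[in_idx]
    return out

x = np.arange(3.0).reshape(1, 3, 1)
print(broadcast_by_extract(x, (2, 1, 3, 4)).shape)  # (2, 1, 3, 4)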
@@ -73,10 +73,10 @@ Value createElementwiseLinalgGeneric(
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);

 // Broadcasts input tensor based on the broadcastToShape.
-LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
-                                    Value input,
-                                    SmallVector<Value> broadcastToShape,
-                                    Value &result);
+LogicalResult
+broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
+                      SmallVector<Value> broadcastToShape, Value &result,
+                      SmallVector<bool> useBroadcastToShape = {});

 // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
 // <?x?xf32>
@@ -1394,6 +1394,32 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils):

 # ==============================================================================

+class BroadcastDynamicDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, -1, 1, -1], torch.float32, True),
+        ([1, -1, 1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        dim_at_index_1 = torch.ops.aten.size(x, 1)
+        dim_at_index_3 = torch.ops.aten.size(x, 3)
+        res = torch.ops.aten.broadcast_to(y, [1, dim_at_index_1, 1, dim_at_index_3])
+        return res
+
+
+@register_test_case(module_factory=lambda: BroadcastDynamicDimModule())
+def BroadcastDynamicDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 1, 4), tu.rand(1, 1, 1, 1))
+
+
+# ==============================================================================
+
+
 class RollModule(torch.nn.Module):

     def __init__(self):