[MLIR][TORCH] Add TorchToTosa lowering for aten.broadcast_to op (#1386)

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1395/head
Vivek Khandelwal 2022-09-20 22:34:51 +05:30 committed by GitHub
parent 0e2e94d542
commit 1ffd42bbde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 0 deletions

View File

@ -364,6 +364,7 @@ TOSA_PASS_SET = {
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"_LogSoftmaxModuleStable_basic",
"BroadcastToIdentityCaseStaticModule_basic",
}
LTC_XFAIL_SET = {

View File

@ -2940,6 +2940,40 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isIntOrFloat()) {
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
}
SmallVector<int64_t> outShape;
if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants");
SmallVector<int64_t> inputShape(selfType.getShape());
if (!llvm::equal(inputShape, outShape))
return rewriter.notifyMatchFailure(op,
"Only identity cases are supported.");
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
rewriter.getI64ArrayAttr(outShape));
return success();
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@ -3562,6 +3596,7 @@ public:
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,

View File

@ -1047,6 +1047,28 @@ def BroadcastToModule_basic(module, tu: TestUtils):
# ==============================================================================
class BroadcastToIdentityCaseStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 1, 1], torch.float32, True),
])
def forward(self, x):
return torch.broadcast_to(x, [3, 1, 1])
@register_test_case(module_factory=lambda: BroadcastToIdentityCaseStaticModule())
def BroadcastToIdentityCaseStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))
# ==============================================================================
class RollModule(torch.nn.Module):
def __init__(self):