mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchToTosa lowering for aten.broadcast_to op (#1386)
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1395/head
parent
0e2e94d542
commit
1ffd42bbde
|
@ -364,6 +364,7 @@ TOSA_PASS_SET = {
|
||||||
"ArgmaxModule_keepDim",
|
"ArgmaxModule_keepDim",
|
||||||
"ArgmaxModule_with_dim",
|
"ArgmaxModule_with_dim",
|
||||||
"_LogSoftmaxModuleStable_basic",
|
"_LogSoftmaxModuleStable_basic",
|
||||||
|
"BroadcastToIdentityCaseStaticModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
LTC_XFAIL_SET = {
|
LTC_XFAIL_SET = {
|
||||||
|
|
|
@ -2940,6 +2940,40 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
|
AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
// Not a tensor type.
|
||||||
|
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||||
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only tensor types with static shape are supported");
|
||||||
|
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
if (!selfElemTy.isIntOrFloat()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> outShape;
|
||||||
|
if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"size must consist of Scalar constants");
|
||||||
|
|
||||||
|
SmallVector<int64_t> inputShape(selfType.getShape());
|
||||||
|
if (!llvm::equal(inputShape, outShape))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Only identity cases are supported.");
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||||
|
rewriter.getI64ArrayAttr(outShape));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename AtenOpT, typename TosaOpT>
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||||
public:
|
public:
|
||||||
|
@ -3562,6 +3596,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
|
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
|
|
@ -1047,6 +1047,28 @@ def BroadcastToModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastToIdentityCaseStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 1, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.broadcast_to(x, [3, 1, 1])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BroadcastToIdentityCaseStaticModule())
|
||||||
|
def BroadcastToIdentityCaseStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RollModule(torch.nn.Module):
|
class RollModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue