[MLIR][TORCH] Add E2E support for aten.expand_as op

This commit decomposes `aten.expand_as` op into `aten.broadcast_to` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/676/head snapshot-20220321.337
Vivek Khandelwal 2022-03-14 13:42:37 +05:30
parent 63fb1e5aad
commit 4c0cd5c23d
6 changed files with 85 additions and 3 deletions

View File

@ -1452,3 +1452,42 @@ class BincountMinlengthModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BincountMinlengthModule())
def BincountMinlengthModule_basic(module, tu: TestUtils):
module.forward(torch.randint(5, (20,)))
# ==============================================================================
class ExpandAsFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, 1, 1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.expand_as(x, y)
@register_test_case(module_factory=lambda: ExpandAsFloatModule())
def ExpandAsFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1), tu.rand(3, 4, 5))
class ExpandAsIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 1, 1], torch.int64, True),
([-1, -1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.expand_as(x, y)
@register_test_case(module_factory=lambda: ExpandAsIntModule())
def ExpandAsIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (1, 1, 1)), torch.randint(200, (4, 5, 6)))

View File

@ -1429,6 +1429,23 @@ public:
};
} // namespace
namespace {
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExpandAsOp op,
PatternRewriter &rewriter) const override {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.other());
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
sizeList);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -1534,6 +1551,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenFullLikeOp>();
patterns.add<DecomposeAtenIndexPutOp>(context);
target.addIllegalOp<AtenIndexPutOp>();
patterns.add<DecomposeAtenExpandAsOp>(context);
target.addIllegalOp<AtenExpandAsOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -515,9 +515,9 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(
op)) {
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
AtenConstantPadNdOp, AtenIndexTensorOp,
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = operands[0]->getValue().dtype;

View File

@ -1077,6 +1077,10 @@ module {
}
return %6 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.expand_as"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.broadcast_to"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>

View File

@ -530,6 +530,9 @@ def atenembedding(weight: List[int], indices: List[int], padding_idx: int = -
def atenexpand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
return upstream_shape_helpers.expand(self, size)
def atenexpand_as(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_helpers.unary(other)
def atenbroadcast_to(self: List[int], size: List[int]) -> List[int]:
return upstream_shape_helpers.expand(self, size)

View File

@ -662,3 +662,20 @@ func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtens
%0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.expand_as(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[OTHER]], %[[INT1]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[OTHER]], %[[INT2]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}