mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Only unroll prim loop-like ops within a `torch.shape.calculate` region (#3812)
Reports a match failure for the pattern `FullyUnrollPrimLoop` when the
loop op is not in a region defined by a `torch.shape.calculate` op.
This is needed to avoid unrolling prim loops generated by ONNX IR, since
we are applying shape refinement in the
`torch-onnx-to-torch-backend-pipeline` introduced in fa4794d
.
See also the discussion in
<https://github.com/iree-org/iree/pull/18867#discussion_r1811101655>
pull/3790/head
parent
aca33f1742
commit
55ff110dc2
|
@ -32,9 +32,6 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// TODO: Only unroll inside the shape calculation region.
|
|
||||||
// Maybe do this by only applying patterns and folding greedily on the ops
|
|
||||||
// inside the region + the shape.calculate op itself?
|
|
||||||
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
|
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
@ -42,6 +39,12 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
|
// Only unroll loops if they are contained in a shape calculate region.
|
||||||
|
Region *region = op->getParentRegion();
|
||||||
|
Operation *parentOp = region->getParentOp();
|
||||||
|
if (!parentOp || !isa<Torch::ShapeCalculateOp>(parentOp))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Loop is not contained in a shape calculation region.");
|
||||||
if (!op.isForLike())
|
if (!op.isForLike())
|
||||||
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
|
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
|
||||||
int64_t maxTripCount;
|
int64_t maxTripCount;
|
||||||
|
|
|
@ -152,6 +152,23 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch
|
||||||
return %0 : !torch.vtensor
|
return %0 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region(
|
||||||
|
// CHECK: %[[LOOP:.*]] = torch.prim.Loop
|
||||||
|
func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%0 = torch.prim.Loop %arg2, %true, init(%arg0) {
|
||||||
|
^bb0(%arg3: !torch.int, %arg4: !torch.vtensor):
|
||||||
|
%1 = torch.shape.calculate {
|
||||||
|
torch.shape.calculate.yield %arg4 : !torch.vtensor
|
||||||
|
} shapes {
|
||||||
|
torch.prim.Print(%arg3) : !torch.int
|
||||||
|
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
|
||||||
|
} : !torch.vtensor
|
||||||
|
torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor)
|
||||||
|
} : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
|
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
|
||||||
|
|
Loading…
Reference in New Issue