From 8860b5c55d40e78ac68808f0e3de5f991e4eeb45 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 16 Jun 2021 10:23:26 -0700 Subject: [PATCH] Add `torch.prim.If` This removes the use of `scf.if`, which required laundering back and forth between `i1` and `!torch.bool` in the frontend. We will eventually lower this op to `scf.if`, but this results in a cleaner IR and layering at the frontend. --- .../pytorch/csrc/builder/node_importer.cpp | 9 +- .../test/ivalue_import/submodules-select.py | 2 +- frontends/pytorch/test/node_import/elif.py | 6 +- frontends/pytorch/test/node_import/if.py | 19 ++-- frontends/pytorch/test/node_import/prim.py | 7 +- include/npcomp/Dialect/Torch/IR/TorchOps.td | 45 +++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 93 +++++++++++++++++++ test/Dialect/Torch/canonicalize.mlir | 16 ++++ test/Dialect/Torch/ops.mlir | 11 +++ 9 files changed, 184 insertions(+), 24 deletions(-) diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp index cbfd29a22..b8b85950f 100644 --- a/frontends/pytorch/csrc/builder/node_importer.cpp +++ b/frontends/pytorch/csrc/builder/node_importer.cpp @@ -173,21 +173,16 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { } if (kind == c10::prim::If) { - // TorchScript will already have an explicit op to determine truthiness. So - // all we need to do here is launder !torch.bool to i1 for `scf.if`. - MlirOperation pred = createMlirOperationAtEnd( - appendToBlock, "torch.to_i1", loc, mlirIntegerTypeGet(context, 1), - lookupMappedValue(node->input())); std::vector resultTypes = getMlirTypesFromValues(loc, node->outputs()); MlirOperation operation = createMlirOperationAtEnd( - appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0), + appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()), resultTypes, mlirRegionCreate(), mlirRegionCreate()); mapResults(node, operation); auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { createMlirOperationAtEnd( - appendToBlock, "scf.yield", loc, + appendToBlock, "torch.prim.If.yield", loc, derefineValues(yieldedValues, resultTypes, loc, appendToBlock)); }; mlirRegionAppendOwnedBlock( diff --git a/frontends/pytorch/test/ivalue_import/submodules-select.py b/frontends/pytorch/test/ivalue_import/submodules-select.py index 088a3c388..4236611e6 100644 --- a/frontends/pytorch/test/ivalue_import/submodules-select.py +++ b/frontends/pytorch/test/ivalue_import/submodules-select.py @@ -27,7 +27,7 @@ class TestModule(torch.nn.Module): # CHECK-LABEL: func private @{{.*}}TestModule.forward def forward(self, b: bool): # Modules with the same class can be selected between. - # CHECK: %[[MOD:.*]] = scf.if + # CHECK: %[[MOD:.*]] = torch.prim.If s = self.s1 if b else self.s2 # CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] () # CHECK: return %[[N]] diff --git a/frontends/pytorch/test/node_import/elif.py b/frontends/pytorch/test/node_import/elif.py index 2e13f8acb..5d68d7e59 100644 --- a/frontends/pytorch/test/node_import/elif.py +++ b/frontends/pytorch/test/node_import/elif.py @@ -13,10 +13,10 @@ mb = torch_mlir.ModuleBuilder() @mb.import_function @torch.jit.script def f(b: bool, i: int): - # elif is modeled as a nested if - # CHECK: scf.if{{.*}}{ + # elif is modeled as a nested if, so we only need to do cursory checking. + # CHECK: torch.prim.If {{.*}} { # CHECK: } else { - # CHECK: scf.if{{.*}}{ + # CHECK: torch.prim.If {{.*}} { # CHECK: } else { # CHECK: } # CHECK: } diff --git a/frontends/pytorch/test/node_import/if.py b/frontends/pytorch/test/node_import/if.py index f491b746b..8afe4663e 100644 --- a/frontends/pytorch/test/node_import/if.py +++ b/frontends/pytorch/test/node_import/if.py @@ -9,38 +9,39 @@ import torch_mlir mb = torch_mlir.ModuleBuilder() +# Note: The "if without else" case is handled by yielding None from the +# else branch and making all defined values optional, so no special handling +# is needed. + # CHECK-LABEL: @__torch__.prim_If( # CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[I:.*]]: i64) -> i64 { @mb.import_function @torch.jit.script def prim_If(b: bool, i: int): - # CHECK: %[[I1:.*]] = torch.to_i1 %[[B]] - # CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) { + # CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (i64) { # CHECK: %[[ADD:.*]] = torch.aten.add.int %[[I]], %[[I]] - # CHECK: scf.yield %[[ADD]] : i64 + # CHECK: torch.prim.If.yield %[[ADD]] : i64 # CHECK: } else { # CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[I]], %[[I]] - # CHECK: scf.yield %[[MUL]] : i64 + # CHECK: torch.prim.If.yield %[[MUL]] : i64 # CHECK: } # CHECK: return %[[RES:.*]] : i64 if b: return i + i else: return i * i - # elif is modeled as a nested if, so no need to specially test it here. # CHECK-LABEL: func @__torch__.prim_If_derefine( # CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional { # CHECK: %[[NONE:.*]] = torch.constant.none -# CHECK: %[[PRED:.*]] = torch.to_i1 %[[B]] -# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional) { +# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (!torch.optional) { # CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional -# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional +# CHECK: torch.prim.If.yield %[[NONE_DEREFINED]] : !torch.optional # CHECK: } else { # CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 to !torch.optional -# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional +# CHECK: torch.prim.If.yield %[[I_DEREFINED]] : !torch.optional # CHECK: } # CHECK: return %[[RES:.*]] : !torch.optional @mb.import_function diff --git a/frontends/pytorch/test/node_import/prim.py b/frontends/pytorch/test/node_import/prim.py index da67efdd5..003cae43d 100644 --- a/frontends/pytorch/test/node_import/prim.py +++ b/frontends/pytorch/test/node_import/prim.py @@ -49,12 +49,11 @@ def prim_RaiseException(): # CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[C3:.*]] = torch.constant.int 3 : i64 # CHECK: %[[IS_NONE:.*]] = torch.aten.__is__ %[[ARG]], %[[NONE]] : !torch.optional, !torch.none -> !torch.bool -# CHECK: %[[COND:.*]] = torch.to_i1 %[[IS_NONE]] -# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) { -# CHECK: scf.yield %[[C3]] : i64 +# CHECK: %[[RESULT:.*]] = torch.prim.If %[[IS_NONE]] -> (i64) { +# CHECK: torch.prim.If.yield %[[C3]] : i64 # CHECK: } else { # CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional -> i64 -# CHECK: scf.yield %[[CASTED]] : i64 +# CHECK: torch.prim.If.yield %[[CASTED]] : i64 # CHECK: } # CHECK: return %[[RESULT:.*]] : i64 @mb.import_function diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index 2a9146d37..94c80e061 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -453,6 +453,51 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [ }]; } +def Torch_PrimIfOp : Torch_Op<"prim.If", [ + DeclareOpInterfaceMethods]> { + let summary = "TorchScript prim::If op"; + let description = [{ + This op (together with prim.If.yield) define a conditional control flow + construct. It is analogous to `scf.if` for MLIR folks that are familiar + with that. The main differences from that op are: + + - `!torch.bool` condition value. + - The "else" region is always present. This is reflective of invariants of + the TorchScript IR. + - No special prettiness for the "no yielded values" case. These are + interesting for modeling mostly-non-SSA programs, but TorchScript IR + is already in SSA form. + + See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#if + }]; + + let arguments = (ins Torch_BoolType:$condition); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion); + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parsePrimIfOp(parser, result); }]; + let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }]; + let hasCanonicalizer = 1; +} + +def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [ + Terminator, + HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> { + let summary = "yield-like terminator for torch.prim.If"; + let description = [{ + Does not correspond to any torch prim op directly (the way that they model + blocks has a built-in notion of yield-like terminator). + }]; + + let arguments = (ins + Variadic:$results + ); + let results = (outs); + + let assemblyFormat = [{ + attr-dict ($results^ `:` type($results))? + }]; +} //===----------------------------------------------------------------------===// // Ops corresponding to prim::Constant diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 82564437b..10521892f 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -198,6 +198,99 @@ void PrimLoopOp::getSuccessorRegions( regions.emplace_back(getResults()); } +//===----------------------------------------------------------------------===// +// PrimIfOp +//===----------------------------------------------------------------------===// + +static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) { + // Create the regions. + result.regions.reserve(2); + Region *thenRegion = result.addRegion(); + Region *elseRegion = result.addRegion(); + + auto &builder = parser.getBuilder(); + OpAsmParser::OperandType cond; + Type boolType = builder.getType(); + if (parser.parseOperand(cond) || + parser.resolveOperand(cond, boolType, result.operands)) + return failure(); + // Parse results type list. + if (parser.parseArrowTypeList(result.types)) + return failure(); + // Parse the 'then' region. + if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + // Parse the 'else' region. + if (parser.parseKeyword("else")) + return failure(); + if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} + +static void print(OpAsmPrinter &p, PrimIfOp op) { + p << PrimIfOp::getOperationName() << " " << op.condition(); + p << " -> (" << op.getResultTypes() << ")"; + p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false); + p << " else"; + p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false); + + p.printOptionalAttrDict(op->getAttrs()); +} + +void PrimIfOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl ®ions) { + // The `then` and the `else` region branch back to the parent operation. + if (index.hasValue()) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + + // If the condition is constant, we can give a more precise answer. + if (auto condAttr = operands.front().dyn_cast_or_null()) { + Region *executedRegion = + condAttr.getValue().isOneValue() ? &thenRegion() : &elseRegion(); + regions.push_back(RegionSuccessor(executedRegion)); + return; + } + + // If the condition isn't constant, both regions may be executed. + regions.push_back(RegionSuccessor(&thenRegion())); + regions.push_back(RegionSuccessor(&elseRegion())); + return; +} + +/// Replaces the given op with the contents of the given single-block region, +/// using the operands of the block terminator to replace operation results. +static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, + Region ®ion, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-region block"); + Block *block = ®ion.front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + +void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + // If the condition is constant, delete the dead branch and inline the live + // branch. + patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { + auto constantBool = op.condition().getDefiningOp(); + if (!constantBool) + return rewriter.notifyMatchFailure(op, "non-constant condition"); + replaceOpWithRegion( + rewriter, op, constantBool.value() ? op.thenRegion() : op.elseRegion()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // DerefineOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 91c928f5f..3c4aa0a08 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -144,3 +144,19 @@ func @torch.constant.bool$constantlike() -> (!torch.bool, !torch.bool, !torch.bo %2 = torch.constant.bool false return %0, %1, %2 : !torch.bool, !torch.bool, !torch.bool } + +// CHECK-LABEL: func @torch.prim.If$erase_dead_branch( +// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { +// CHECK-NEXT: %[[RET:.*]] = torch.aten.add.int %[[ARG]], %[[ARG]] : i64, i64 -> i64 +// CHECK-NEXT: return %[[RET]] : i64 +func @torch.prim.If$erase_dead_branch(%arg0: i64) -> i64 { + %true = torch.constant.bool true + %0 = torch.prim.If %true -> (i64) { + %1 = torch.aten.add.int %arg0, %arg0 : i64, i64 -> i64 + torch.prim.If.yield %1 : i64 + } else { + %1 = torch.aten.mul.int %arg0, %arg0 : i64, i64 -> i64 + torch.prim.If.yield %1 : i64 + } + return %0 : i64 +} diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 2357a1cd4..d2bf155cb 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -62,6 +62,17 @@ func @derefine(%arg0: !torch.tensor) -> !torch.optional { return %0 : !torch.optional } +func @torch.prim.If(%arg0: !torch.bool, %arg1: i64) -> i64 { + %0 = torch.prim.If %arg0 -> (i64) { + %1 = torch.aten.add.int %arg1, %arg1 : i64, i64 -> i64 + torch.prim.If.yield %1 : i64 + } else { + %1 = torch.aten.mul.int %arg1, %arg1 : i64, i64 -> i64 + torch.prim.If.yield %1 : i64 + } + return %0 : i64 +} + // CHECK: %true = torch.constant.bool true %true = torch.constant.bool true // CHECK: %false = torch.constant.bool false