Add `torch.prim.If`

This removes the use of `scf.if`, which required laundering back and
forth between `i1` and `!torch.bool` in the frontend. We will eventually
lower this op to `scf.if`, but this results in a cleaner IR and layering
at the frontend.
pull/228/head
Sean Silva 2021-06-16 10:23:26 -07:00
parent 784156a998
commit 8860b5c55d
9 changed files with 184 additions and 24 deletions

View File

@ -173,21 +173,16 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
if (kind == c10::prim::If) {
// TorchScript will already have an explicit op to determine truthiness. So
// all we need to do here is launder !torch.bool to i1 for `scf.if`.
MlirOperation pred = createMlirOperationAtEnd(
appendToBlock, "torch.to_i1", loc, mlirIntegerTypeGet(context, 1),
lookupMappedValue(node->input()));
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
resultTypes, mlirRegionCreate(), mlirRegionCreate());
mapResults(node, operation);
auto createTerminator =
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "scf.yield", loc,
appendToBlock, "torch.prim.If.yield", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
mlirRegionAppendOwnedBlock(

View File

@ -27,7 +27,7 @@ class TestModule(torch.nn.Module):
# CHECK-LABEL: func private @{{.*}}TestModule.forward
def forward(self, b: bool):
# Modules with the same class can be selected between.
# CHECK: %[[MOD:.*]] = scf.if
# CHECK: %[[MOD:.*]] = torch.prim.If
s = self.s1 if b else self.s2
# CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] ()
# CHECK: return %[[N]]

View File

@ -13,10 +13,10 @@ mb = torch_mlir.ModuleBuilder()
@mb.import_function
@torch.jit.script
def f(b: bool, i: int):
# elif is modeled as a nested if
# CHECK: scf.if{{.*}}{
# elif is modeled as a nested if, so we only need to do cursory checking.
# CHECK: torch.prim.If {{.*}} {
# CHECK: } else {
# CHECK: scf.if{{.*}}{
# CHECK: torch.prim.If {{.*}} {
# CHECK: } else {
# CHECK: }
# CHECK: }

View File

@ -9,38 +9,39 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# Note: The "if without else" case is handled by yielding None from the
# else branch and making all defined values optional, so no special handling
# is needed.
# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
@mb.import_function
@torch.jit.script
def prim_If(b: bool, i: int):
# CHECK: %[[I1:.*]] = torch.to_i1 %[[B]]
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (i64) {
# CHECK: %[[ADD:.*]] = torch.aten.add.int %[[I]], %[[I]]
# CHECK: scf.yield %[[ADD]] : i64
# CHECK: torch.prim.If.yield %[[ADD]] : i64
# CHECK: } else {
# CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[I]], %[[I]]
# CHECK: scf.yield %[[MUL]] : i64
# CHECK: torch.prim.If.yield %[[MUL]] : i64
# CHECK: }
# CHECK: return %[[RES:.*]] : i64
if b:
return i + i
else:
return i * i
# elif is modeled as a nested if, so no need to specially test it here.
# CHECK-LABEL: func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[PRED:.*]] = torch.to_i1 %[[B]]
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (!torch.optional<i64>) {
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<i64>
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
# CHECK: torch.prim.If.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
# CHECK: } else {
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 to !torch.optional<i64>
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
# CHECK: torch.prim.If.yield %[[I_DEREFINED]] : !torch.optional<i64>
# CHECK: }
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
@mb.import_function

View File

@ -49,12 +49,11 @@ def prim_RaiseException():
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[C3:.*]] = torch.constant.int 3 : i64
# CHECK: %[[IS_NONE:.*]] = torch.aten.__is__ %[[ARG]], %[[NONE]] : !torch.optional<i64>, !torch.none -> !torch.bool
# CHECK: %[[COND:.*]] = torch.to_i1 %[[IS_NONE]]
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
# CHECK: scf.yield %[[C3]] : i64
# CHECK: %[[RESULT:.*]] = torch.prim.If %[[IS_NONE]] -> (i64) {
# CHECK: torch.prim.If.yield %[[C3]] : i64
# CHECK: } else {
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<i64> -> i64
# CHECK: scf.yield %[[CASTED]] : i64
# CHECK: torch.prim.If.yield %[[CASTED]] : i64
# CHECK: }
# CHECK: return %[[RESULT:.*]] : i64
@mb.import_function

View File

@ -453,6 +453,51 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
}];
}
def Torch_PrimIfOp : Torch_Op<"prim.If", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
let summary = "TorchScript prim::If op";
let description = [{
This op (together with prim.If.yield) define a conditional control flow
construct. It is analogous to `scf.if` for MLIR folks that are familiar
with that. The main differences from that op are:
- `!torch.bool` condition value.
- The "else" region is always present. This is reflective of invariants of
the TorchScript IR.
- No special prettiness for the "no yielded values" case. These are
interesting for modeling mostly-non-SSA programs, but TorchScript IR
is already in SSA form.
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#if
}];
let arguments = (ins Torch_BoolType:$condition);
let results = (outs Variadic<AnyTorchType>:$results);
let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parsePrimIfOp(parser, result); }];
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
let hasCanonicalizer = 1;
}
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> {
let summary = "yield-like terminator for torch.prim.If";
let description = [{
Does not correspond to any torch prim op directly (the way that they model
blocks has a built-in notion of yield-like terminator).
}];
let arguments = (ins
Variadic<AnyTorchType>:$results
);
let results = (outs);
let assemblyFormat = [{
attr-dict ($results^ `:` type($results))?
}];
}
//===----------------------------------------------------------------------===//
// Ops corresponding to prim::Constant

View File

@ -198,6 +198,99 @@ void PrimLoopOp::getSuccessorRegions(
regions.emplace_back(getResults());
}
//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//
static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
// Create the regions.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder();
OpAsmParser::OperandType cond;
Type boolType = builder.getType<Torch::BoolType>();
if (parser.parseOperand(cond) ||
parser.resolveOperand(cond, boolType, result.operands))
return failure();
// Parse results type list.
if (parser.parseArrowTypeList(result.types))
return failure();
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
// Parse the 'else' region.
if (parser.parseKeyword("else"))
return failure();
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
return success();
}
static void print(OpAsmPrinter &p, PrimIfOp op) {
p << PrimIfOp::getOperationName() << " " << op.condition();
p << " -> (" << op.getResultTypes() << ")";
p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
p << " else";
p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op->getAttrs());
}
void PrimIfOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index.hasValue()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
// If the condition is constant, we can give a more precise answer.
if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
Region *executedRegion =
condAttr.getValue().isOneValue() ? &thenRegion() : &elseRegion();
regions.push_back(RegionSuccessor(executedRegion));
return;
}
// If the condition isn't constant, both regions may be executed.
regions.push_back(RegionSuccessor(&thenRegion()));
regions.push_back(RegionSuccessor(&elseRegion()));
return;
}
/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Region &region, ValueRange blockArgs = {}) {
assert(llvm::hasSingleElement(region) && "expected single-region block");
Block *block = &region.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.mergeBlockBefore(block, op, blockArgs);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}
void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// If the condition is constant, delete the dead branch and inline the live
// branch.
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
auto constantBool = op.condition().getDefiningOp<Torch::ConstantBoolOp>();
if (!constantBool)
return rewriter.notifyMatchFailure(op, "non-constant condition");
replaceOpWithRegion(
rewriter, op, constantBool.value() ? op.thenRegion() : op.elseRegion());
return success();
});
}
//===----------------------------------------------------------------------===//
// DerefineOp
//===----------------------------------------------------------------------===//

View File

@ -144,3 +144,19 @@ func @torch.constant.bool$constantlike() -> (!torch.bool, !torch.bool, !torch.bo
%2 = torch.constant.bool false
return %0, %1, %2 : !torch.bool, !torch.bool, !torch.bool
}
// CHECK-LABEL: func @torch.prim.If$erase_dead_branch(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK-NEXT: %[[RET:.*]] = torch.aten.add.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
// CHECK-NEXT: return %[[RET]] : i64
func @torch.prim.If$erase_dead_branch(%arg0: i64) -> i64 {
%true = torch.constant.bool true
%0 = torch.prim.If %true -> (i64) {
%1 = torch.aten.add.int %arg0, %arg0 : i64, i64 -> i64
torch.prim.If.yield %1 : i64
} else {
%1 = torch.aten.mul.int %arg0, %arg0 : i64, i64 -> i64
torch.prim.If.yield %1 : i64
}
return %0 : i64
}

View File

@ -62,6 +62,17 @@ func @derefine(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
return %0 : !torch.optional<!torch.tensor>
}
func @torch.prim.If(%arg0: !torch.bool, %arg1: i64) -> i64 {
%0 = torch.prim.If %arg0 -> (i64) {
%1 = torch.aten.add.int %arg1, %arg1 : i64, i64 -> i64
torch.prim.If.yield %1 : i64
} else {
%1 = torch.aten.mul.int %arg1, %arg1 : i64, i64 -> i64
torch.prim.If.yield %1 : i64
}
return %0 : i64
}
// CHECK: %true = torch.constant.bool true
%true = torch.constant.bool true
// CHECK: %false = torch.constant.bool false