mirror of https://github.com/llvm/torch-mlir
Add `torch.prim.If`
This removes the use of `scf.if`, which required laundering back and forth between `i1` and `!torch.bool` in the frontend. We will eventually lower this op to `scf.if`, but this results in a cleaner IR and layering at the frontend.pull/228/head
parent
784156a998
commit
8860b5c55d
|
@ -173,21 +173,16 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
if (kind == c10::prim::If) {
|
||||
// TorchScript will already have an explicit op to determine truthiness. So
|
||||
// all we need to do here is launder !torch.bool to i1 for `scf.if`.
|
||||
MlirOperation pred = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.to_i1", loc, mlirIntegerTypeGet(context, 1),
|
||||
lookupMappedValue(node->input()));
|
||||
std::vector<MlirType> resultTypes =
|
||||
getMlirTypesFromValues(loc, node->outputs());
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
|
||||
appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
|
||||
resultTypes, mlirRegionCreate(), mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
auto createTerminator =
|
||||
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "scf.yield", loc,
|
||||
appendToBlock, "torch.prim.If.yield", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
|
|
|
@ -27,7 +27,7 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK-LABEL: func private @{{.*}}TestModule.forward
|
||||
def forward(self, b: bool):
|
||||
# Modules with the same class can be selected between.
|
||||
# CHECK: %[[MOD:.*]] = scf.if
|
||||
# CHECK: %[[MOD:.*]] = torch.prim.If
|
||||
s = self.s1 if b else self.s2
|
||||
# CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] ()
|
||||
# CHECK: return %[[N]]
|
||||
|
|
|
@ -13,10 +13,10 @@ mb = torch_mlir.ModuleBuilder()
|
|||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def f(b: bool, i: int):
|
||||
# elif is modeled as a nested if
|
||||
# CHECK: scf.if{{.*}}{
|
||||
# elif is modeled as a nested if, so we only need to do cursory checking.
|
||||
# CHECK: torch.prim.If {{.*}} {
|
||||
# CHECK: } else {
|
||||
# CHECK: scf.if{{.*}}{
|
||||
# CHECK: torch.prim.If {{.*}} {
|
||||
# CHECK: } else {
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
|
|
|
@ -9,38 +9,39 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# Note: The "if without else" case is handled by yielding None from the
|
||||
# else branch and making all defined values optional, so no special handling
|
||||
# is needed.
|
||||
|
||||
# CHECK-LABEL: @__torch__.prim_If(
|
||||
# CHECK-SAME: %[[B:.*]]: !torch.bool,
|
||||
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_If(b: bool, i: int):
|
||||
# CHECK: %[[I1:.*]] = torch.to_i1 %[[B]]
|
||||
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
|
||||
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (i64) {
|
||||
# CHECK: %[[ADD:.*]] = torch.aten.add.int %[[I]], %[[I]]
|
||||
# CHECK: scf.yield %[[ADD]] : i64
|
||||
# CHECK: torch.prim.If.yield %[[ADD]] : i64
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[I]], %[[I]]
|
||||
# CHECK: scf.yield %[[MUL]] : i64
|
||||
# CHECK: torch.prim.If.yield %[[MUL]] : i64
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RES:.*]] : i64
|
||||
if b:
|
||||
return i + i
|
||||
else:
|
||||
return i * i
|
||||
# elif is modeled as a nested if, so no need to specially test it here.
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_If_derefine(
|
||||
# CHECK-SAME: %[[B:.*]]: !torch.bool,
|
||||
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
|
||||
# CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
# CHECK: %[[PRED:.*]] = torch.to_i1 %[[B]]
|
||||
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
|
||||
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (!torch.optional<i64>) {
|
||||
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<i64>
|
||||
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: torch.prim.If.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 to !torch.optional<i64>
|
||||
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: torch.prim.If.yield %[[I_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
|
||||
@mb.import_function
|
||||
|
|
|
@ -49,12 +49,11 @@ def prim_RaiseException():
|
|||
# CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
# CHECK: %[[C3:.*]] = torch.constant.int 3 : i64
|
||||
# CHECK: %[[IS_NONE:.*]] = torch.aten.__is__ %[[ARG]], %[[NONE]] : !torch.optional<i64>, !torch.none -> !torch.bool
|
||||
# CHECK: %[[COND:.*]] = torch.to_i1 %[[IS_NONE]]
|
||||
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
|
||||
# CHECK: scf.yield %[[C3]] : i64
|
||||
# CHECK: %[[RESULT:.*]] = torch.prim.If %[[IS_NONE]] -> (i64) {
|
||||
# CHECK: torch.prim.If.yield %[[C3]] : i64
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<i64> -> i64
|
||||
# CHECK: scf.yield %[[CASTED]] : i64
|
||||
# CHECK: torch.prim.If.yield %[[CASTED]] : i64
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RESULT:.*]] : i64
|
||||
@mb.import_function
|
||||
|
|
|
@ -453,6 +453,51 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimIfOp : Torch_Op<"prim.If", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
|
||||
let summary = "TorchScript prim::If op";
|
||||
let description = [{
|
||||
This op (together with prim.If.yield) define a conditional control flow
|
||||
construct. It is analogous to `scf.if` for MLIR folks that are familiar
|
||||
with that. The main differences from that op are:
|
||||
|
||||
- `!torch.bool` condition value.
|
||||
- The "else" region is always present. This is reflective of invariants of
|
||||
the TorchScript IR.
|
||||
- No special prettiness for the "no yielded values" case. These are
|
||||
interesting for modeling mostly-non-SSA programs, but TorchScript IR
|
||||
is already in SSA form.
|
||||
|
||||
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#if
|
||||
}];
|
||||
|
||||
let arguments = (ins Torch_BoolType:$condition);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parsePrimIfOp(parser, result); }];
|
||||
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
|
||||
Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> {
|
||||
let summary = "yield-like terminator for torch.prim.If";
|
||||
let description = [{
|
||||
Does not correspond to any torch prim op directly (the way that they model
|
||||
blocks has a built-in notion of yield-like terminator).
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict ($results^ `:` type($results))?
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops corresponding to prim::Constant
|
||||
|
|
|
@ -198,6 +198,99 @@ void PrimLoopOp::getSuccessorRegions(
|
|||
regions.emplace_back(getResults());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimIfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Create the regions.
|
||||
result.regions.reserve(2);
|
||||
Region *thenRegion = result.addRegion();
|
||||
Region *elseRegion = result.addRegion();
|
||||
|
||||
auto &builder = parser.getBuilder();
|
||||
OpAsmParser::OperandType cond;
|
||||
Type boolType = builder.getType<Torch::BoolType>();
|
||||
if (parser.parseOperand(cond) ||
|
||||
parser.resolveOperand(cond, boolType, result.operands))
|
||||
return failure();
|
||||
// Parse results type list.
|
||||
if (parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
// Parse the 'then' region.
|
||||
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
// Parse the 'else' region.
|
||||
if (parser.parseKeyword("else"))
|
||||
return failure();
|
||||
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
// Parse the optional attribute list.
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PrimIfOp op) {
|
||||
p << PrimIfOp::getOperationName() << " " << op.condition();
|
||||
p << " -> (" << op.getResultTypes() << ")";
|
||||
p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
|
||||
p << " else";
|
||||
p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
|
||||
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
}
|
||||
|
||||
void PrimIfOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The `then` and the `else` region branch back to the parent operation.
|
||||
if (index.hasValue()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
||||
// If the condition is constant, we can give a more precise answer.
|
||||
if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
|
||||
Region *executedRegion =
|
||||
condAttr.getValue().isOneValue() ? &thenRegion() : &elseRegion();
|
||||
regions.push_back(RegionSuccessor(executedRegion));
|
||||
return;
|
||||
}
|
||||
|
||||
// If the condition isn't constant, both regions may be executed.
|
||||
regions.push_back(RegionSuccessor(&thenRegion()));
|
||||
regions.push_back(RegionSuccessor(&elseRegion()));
|
||||
return;
|
||||
}
|
||||
|
||||
/// Replaces the given op with the contents of the given single-block region,
|
||||
/// using the operands of the block terminator to replace operation results.
|
||||
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
|
||||
Region ®ion, ValueRange blockArgs = {}) {
|
||||
assert(llvm::hasSingleElement(region) && "expected single-region block");
|
||||
Block *block = ®ion.front();
|
||||
Operation *terminator = block->getTerminator();
|
||||
ValueRange results = terminator->getOperands();
|
||||
rewriter.mergeBlockBefore(block, op, blockArgs);
|
||||
rewriter.replaceOp(op, results);
|
||||
rewriter.eraseOp(terminator);
|
||||
}
|
||||
|
||||
void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// If the condition is constant, delete the dead branch and inline the live
|
||||
// branch.
|
||||
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
|
||||
auto constantBool = op.condition().getDefiningOp<Torch::ConstantBoolOp>();
|
||||
if (!constantBool)
|
||||
return rewriter.notifyMatchFailure(op, "non-constant condition");
|
||||
replaceOpWithRegion(
|
||||
rewriter, op, constantBool.value() ? op.thenRegion() : op.elseRegion());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DerefineOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -144,3 +144,19 @@ func @torch.constant.bool$constantlike() -> (!torch.bool, !torch.bool, !torch.bo
|
|||
%2 = torch.constant.bool false
|
||||
return %0, %1, %2 : !torch.bool, !torch.bool, !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.If$erase_dead_branch(
|
||||
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
||||
// CHECK-NEXT: %[[RET:.*]] = torch.aten.add.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
|
||||
// CHECK-NEXT: return %[[RET]] : i64
|
||||
func @torch.prim.If$erase_dead_branch(%arg0: i64) -> i64 {
|
||||
%true = torch.constant.bool true
|
||||
%0 = torch.prim.If %true -> (i64) {
|
||||
%1 = torch.aten.add.int %arg0, %arg0 : i64, i64 -> i64
|
||||
torch.prim.If.yield %1 : i64
|
||||
} else {
|
||||
%1 = torch.aten.mul.int %arg0, %arg0 : i64, i64 -> i64
|
||||
torch.prim.If.yield %1 : i64
|
||||
}
|
||||
return %0 : i64
|
||||
}
|
||||
|
|
|
@ -62,6 +62,17 @@ func @derefine(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
|
|||
return %0 : !torch.optional<!torch.tensor>
|
||||
}
|
||||
|
||||
func @torch.prim.If(%arg0: !torch.bool, %arg1: i64) -> i64 {
|
||||
%0 = torch.prim.If %arg0 -> (i64) {
|
||||
%1 = torch.aten.add.int %arg1, %arg1 : i64, i64 -> i64
|
||||
torch.prim.If.yield %1 : i64
|
||||
} else {
|
||||
%1 = torch.aten.mul.int %arg1, %arg1 : i64, i64 -> i64
|
||||
torch.prim.If.yield %1 : i64
|
||||
}
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// CHECK: %true = torch.constant.bool true
|
||||
%true = torch.constant.bool true
|
||||
// CHECK: %false = torch.constant.bool false
|
||||
|
|
Loading…
Reference in New Issue