Add support for prim::Loop op.

This is a funny one. It combines a `for` and `while` loop in one op. We
will need to write some conversions to `scf`.
pull/176/head
Sean Silva 2021-03-01 15:00:32 -08:00
parent 7dfd6f697e
commit 939d36906f
6 changed files with 132 additions and 0 deletions

View File

@ -140,6 +140,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::Loop) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()), mlirRegionCreate());
mapResults(node, operation);
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
return;
}
if (kind == c10::prim::If) {
// TorchScript will already have an explicit op to determine truthiness. So
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.

View File

@ -0,0 +1,51 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
# CHECK: %[[RESULTS:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[BOOL_TRUE]], init(%[[F_INIT]]) {
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::add" %[[F_ITER]], %[[IV]] : (f64, i64) -> f64 {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[RESULTS:.*]] : f64
@mb.import_function
@torch.jit.script
def prim_Loop_forlike(n: int):
f = 0.0
for i in range(n):
f += i
return f
# CHECK-LABEL: func @prim_Loop_whilelike(
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[VAL_9:.*]] : f64
@mb.import_function
@torch.jit.script
def prim_Loop_whilelike(n: int):
f = 3.2
while f < n:
f = f * f
return f
mb.module.operation.print()
print()

View File

@ -13,6 +13,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"

View File

@ -12,6 +12,7 @@
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
@ -377,6 +378,49 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
}];
}
def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> {
let summary = "TorchScript prim::Loop op";
let description = [{
This op (together with prim.Loop.condition) define a looping construct
that combines `for` and `while` behavior.
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
}];
let arguments = (ins
I64:$maxTripCount,
Basicpy_BoolType:$initialCondition,
Variadic<AnyTorchType>:$iterArgsInit
);
let results = (outs Variadic<AnyTorchType>:$results);
let regions = (region SizedRegion<1>:$region);
let assemblyFormat = [{
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
attr-dict `:` functional-type(operands, results)
}];
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
}
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop";
let description = [{
Does not correspond to any torch prim op directly (the way that they model
blocks has a built-in notion of yield-like terminator).
}];
let arguments = (ins Basicpy_BoolType:$shouldContinue, Variadic<AnyTorchType>:$iterArgs);
let results = (outs);
let assemblyFormat = [{
$shouldContinue `iter` `(` $iterArgs `)`
attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
}];
}
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
let summary = "TorchScript prim::NumToTensor op";

View File

@ -16,6 +16,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
MLIRControlFlowInterfaces
NPCOMPBasicpyDialect
NPCOMPNumpyDialect
)

View File

@ -155,5 +155,28 @@ static LogicalResult verify(ClassTypeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// PrimLoopOp
//===----------------------------------------------------------------------===//
OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0);
return iterArgsInit();
}
void PrimLoopOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
(void)operands;
if (!index.hasValue()) {
regions.emplace_back(&region(), region().getArguments().slice(1));
return;
}
assert(*index == 0);
regions.emplace_back(&region(), region().getArguments().slice(1));
regions.emplace_back(getResults());
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"