mirror of https://github.com/llvm/torch-mlir
Add support for prim::Loop op.
This is a funny one. It combines a `for` and `while` loop in one op. We will need to write some conversions to `scf`.pull/176/head
parent
7dfd6f697e
commit
939d36906f
|
@ -140,6 +140,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::Loop) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.Loop", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()), mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::If) {
|
||||
// TorchScript will already have an explicit op to determine truthiness. So
|
||||
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @prim_Loop_forlike(
|
||||
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
|
||||
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
|
||||
# CHECK: %[[RESULTS:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[BOOL_TRUE]], init(%[[F_INIT]]) {
|
||||
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::add" %[[F_ITER]], %[[IV]] : (f64, i64) -> f64 {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||
# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||
# CHECK: return %[[RESULTS:.*]] : f64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_Loop_forlike(n: int):
|
||||
f = 0.0
|
||||
for i in range(n):
|
||||
f += i
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: func @prim_Loop_whilelike(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
|
||||
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
|
||||
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
|
||||
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
||||
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||
# CHECK: return %[[VAL_9:.*]] : f64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_Loop_whilelike(n: int):
|
||||
f = 3.2
|
||||
while f < n:
|
||||
f = f * f
|
||||
return f
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
|
@ -377,6 +378,49 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> {
|
||||
let summary = "TorchScript prim::Loop op";
|
||||
let description = [{
|
||||
This op (together with prim.Loop.condition) define a looping construct
|
||||
that combines `for` and `while` behavior.
|
||||
|
||||
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64:$maxTripCount,
|
||||
Basicpy_BoolType:$initialCondition,
|
||||
Variadic<AnyTorchType>:$iterArgsInit
|
||||
);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
|
||||
}
|
||||
|
||||
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
||||
Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
|
||||
let summary = "yield-like terminator for torch.prim.Loop";
|
||||
let description = [{
|
||||
Does not correspond to any torch prim op directly (the way that they model
|
||||
blocks has a built-in notion of yield-like terminator).
|
||||
}];
|
||||
|
||||
let arguments = (ins Basicpy_BoolType:$shouldContinue, Variadic<AnyTorchType>:$iterArgs);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$shouldContinue `iter` `(` $iterArgs `)`
|
||||
attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
|
||||
let summary = "TorchScript prim::NumToTensor op";
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
MLIRControlFlowInterfaces
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
)
|
||||
|
|
|
@ -155,5 +155,28 @@ static LogicalResult verify(ClassTypeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimLoopOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
|
||||
assert(index == 0);
|
||||
return iterArgsInit();
|
||||
}
|
||||
|
||||
void PrimLoopOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
(void)operands;
|
||||
|
||||
if (!index.hasValue()) {
|
||||
regions.emplace_back(®ion(), region().getArguments().slice(1));
|
||||
return;
|
||||
}
|
||||
assert(*index == 0);
|
||||
regions.emplace_back(®ion(), region().getArguments().slice(1));
|
||||
regions.emplace_back(getResults());
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
|
|
Loading…
Reference in New Issue