mirror of https://github.com/llvm/torch-mlir
Add support for prim::Loop op.
This is a funny one. It combines a `for` and `while` loop in one op. We will need to write some conversions to `scf`.pull/176/head
parent
7dfd6f697e
commit
939d36906f
|
@ -140,6 +140,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (kind == c10::prim::Loop) {
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "torch.prim.Loop", loc,
|
||||||
|
getMlirTypesFromValues(loc, node->outputs()),
|
||||||
|
lookupMappedValues(node->inputs()), mlirRegionCreate());
|
||||||
|
mapResults(node, operation);
|
||||||
|
mlirRegionAppendOwnedBlock(
|
||||||
|
mlirOperationGetRegion(operation, 0),
|
||||||
|
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::If) {
|
if (kind == c10::prim::If) {
|
||||||
// TorchScript will already have an explicit op to determine truthiness. So
|
// TorchScript will already have an explicit op to determine truthiness. So
|
||||||
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
|
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @prim_Loop_forlike(
|
||||||
|
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
|
||||||
|
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
|
||||||
|
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
|
||||||
|
# CHECK: %[[RESULTS:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[BOOL_TRUE]], init(%[[F_INIT]]) {
|
||||||
|
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||||
|
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::add" %[[F_ITER]], %[[IV]] : (f64, i64) -> f64 {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||||
|
# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||||
|
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||||
|
# CHECK: return %[[RESULTS:.*]] : f64
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_Loop_forlike(n: int):
|
||||||
|
f = 0.0
|
||||||
|
for i in range(n):
|
||||||
|
f += i
|
||||||
|
return f
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @prim_Loop_whilelike(
|
||||||
|
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
|
||||||
|
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
|
||||||
|
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
|
||||||
|
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||||
|
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
||||||
|
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||||
|
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||||
|
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||||
|
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||||
|
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||||
|
# CHECK: return %[[VAL_9:.*]] : f64
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_Loop_whilelike(n: int):
|
||||||
|
f = 3.2
|
||||||
|
while f < n:
|
||||||
|
f = f * f
|
||||||
|
return f
|
||||||
|
|
||||||
|
mb.module.operation.print()
|
||||||
|
print()
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
|
|
||||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Torch_Dialect, mnemonic, traits> {
|
: Op<Torch_Dialect, mnemonic, traits> {
|
||||||
|
@ -377,6 +378,49 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> {
|
||||||
|
let summary = "TorchScript prim::Loop op";
|
||||||
|
let description = [{
|
||||||
|
This op (together with prim.Loop.condition) define a looping construct
|
||||||
|
that combines `for` and `while` behavior.
|
||||||
|
|
||||||
|
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I64:$maxTripCount,
|
||||||
|
Basicpy_BoolType:$initialCondition,
|
||||||
|
Variadic<AnyTorchType>:$iterArgsInit
|
||||||
|
);
|
||||||
|
let results = (outs Variadic<AnyTorchType>:$results);
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
|
||||||
|
attr-dict `:` functional-type(operands, results)
|
||||||
|
}];
|
||||||
|
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
||||||
|
Terminator,
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
|
||||||
|
let summary = "yield-like terminator for torch.prim.Loop";
|
||||||
|
let description = [{
|
||||||
|
Does not correspond to any torch prim op directly (the way that they model
|
||||||
|
blocks has a built-in notion of yield-like terminator).
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Basicpy_BoolType:$shouldContinue, Variadic<AnyTorchType>:$iterArgs);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$shouldContinue `iter` `(` $iterArgs `)`
|
||||||
|
attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
|
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
|
||||||
let summary = "TorchScript prim::NumToTensor op";
|
let summary = "TorchScript prim::NumToTensor op";
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
|
MLIRControlFlowInterfaces
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -155,5 +155,28 @@ static LogicalResult verify(ClassTypeOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PrimLoopOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
|
||||||
|
assert(index == 0);
|
||||||
|
return iterArgsInit();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PrimLoopOp::getSuccessorRegions(
|
||||||
|
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
(void)operands;
|
||||||
|
|
||||||
|
if (!index.hasValue()) {
|
||||||
|
regions.emplace_back(®ion(), region().getArguments().slice(1));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
assert(*index == 0);
|
||||||
|
regions.emplace_back(®ion(), region().getArguments().slice(1));
|
||||||
|
regions.emplace_back(getResults());
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
|
|
Loading…
Reference in New Issue