//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_OPS
#define TORCH_OPS

include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
|
// Base class for all ops in the Torch dialect. `mnemonic` is the op name
// after the `torch.` dialect prefix; `traits` are forwarded to ODS `Op`.
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<Torch_Dialect, mnemonic, traits> {
}
|
|
|
|
|
2020-10-23 14:31:34 +08:00
|
|
|
// TODO: Add alias mapping from the signature and use it to implement the
// effects interface (since whether the kernel_call has side effects is
// dependent on its metadata).
def Torch_KernelCallOp : Torch_Op<"kernel_call", [
    DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Calls a Torch custom kernel";
  let description = [{
    Torch kernel calls are matched by the runtime based on signature, including
    the fully qualified kernel name (i.e. "namespace::name") and the tuple of
    argument types. This op models such an invocation.
  }];

  let arguments = (ins
    // Fully qualified kernel name, e.g. "aten::add".
    StrAttr:$kernelName,
    Variadic<AnyTorchType>:$args,
    // Textual types of the kernel's declared signature (not necessarily the
    // same as the MLIR types of $args/$results).
    StrArrayAttr:$sigArgTypes,
    StrArrayAttr:$sigRetTypes,
    BoolAttr:$sigIsVararg,
    BoolAttr:$sigIsVarret,
    BoolAttr:$sigIsMutable
    // TODO: Add alias mapping.
  );
  let results = (outs
    Variadic<AnyTorchType>:$results
  );

  let assemblyFormat = [{
    $kernelName $args `:` functional-type($args, results) attr-dict
  }];
}
|
|
|
|
|
2021-01-28 08:35:44 +08:00
|
|
|
//===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops.
//===----------------------------------------------------------------------===//

def Torch_NnModuleOp : Torch_Op<"nn_module", [
    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
    SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
  let summary = "Constructs a torch.nn.Module";
  let description = [{
    This op is used to represent a torch.nn.Module when importing a
    graph of Python objects.

    This op returns a new torch.nn.Module as an SSA value, with a set of
    declaratively specified properties.

    Example:

    ```mlir
    %2 = torch.nn_module {
      torch.slot "b", %bool_true : !basicpy.BoolType
      torch.slot "i", %num3_i64 : i64
      torch.slot "f", %num : f64
      torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
      torch.slot "submodule", %1 : !torch.nn.Module
    } : !torch.nn.Module<"my_class_name">
    ```

    This op is tightly coupled to the `torch.class_type` op named in the
    `!torch.nn.Module<"my_class_name">` type. Each slot must match precisely
    with the corresponding `torch.attr` in the `torch.class_type`.
    See the documentation for `torch.class_type` for information.
  }];

  let arguments = (ins);
  let results = (outs Torch_NnModuleType:$result);
  let regions = (region SizedRegion<1>:$region);
  // Verification (slot/attr matching) is implemented in C++.
  let verifier = "return ::verify(*this);";

  let assemblyFormat = "$region attr-dict `:` type($result)";

  let extraClassDeclaration = [{
    StringRef getClassName() { return getType().getClassName(); }
  }];
}
|
|
|
|
|
|
|
|
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
    HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
  let summary = "Implicit terminator for torch.nn_module";

  let arguments = (ins);
  let results = (outs);

  let assemblyFormat = "attr-dict";
}
|
|
|
|
|
2021-02-18 03:28:51 +08:00
|
|
|
def Torch_SlotOp : Torch_Op<"slot", [
    HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
  let summary = "Define the value of a slot of a torch.nn.Module";
  let description = [{
    This op specifies that the initial value of the slot `name` of the
    parent torch.nn_module should be `value`, which is allowed to be an
    arbitrary Torch-compatible SSA value, including other !torch.nn.Module's.
  }];

  let arguments = (ins StrAttr:$name, AnyTorchType:$value);
  let results = (outs);

  let assemblyFormat = [{
    $name `,` $value attr-dict `:` type($value)
  }];
}
|
|
|
|
|
2021-02-18 03:28:51 +08:00
|
|
|
//===----------------------------------------------------------------------===//
// Modeling of TorchScript class types
//===----------------------------------------------------------------------===//

def Torch_ClassTypeOp : Torch_Op<"class_type", [
    Symbol,
    SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
  let summary = "Constructs a torch.ClassType";
  let description = [{
    Declares a class type. Class types are the types used to describe
    TorchScript `torch.nn.Module`'s. The terminology "class type" is for
    consistency with TorchScript (a better name in our context might be
    "nn module subtype"). The `sym_name` of this op is the same string
    as in the `!torch.nn.Module<"...">` type.

    Example:

    ```mlir
    // A simple empty torch.class_type, with corresponding torch.nn_module.
    torch.class_type @empty {}
    %submodule = torch.nn_module {} : !torch.nn.Module<"empty">

    // A class type with many members.
    torch.class_type @test {
      torch.attr "b" : !basicpy.BoolType
      torch.attr "i" : i64
      torch.attr "f" : f64
      torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
      torch.attr "submodule" : !torch.nn.Module<"empty">
      torch.method "method", @f
    }
    torch.nn_module {
      // These must match the order and names in the `torch.class_type`.
      torch.slot "b", %bool_true : !basicpy.BoolType
      torch.slot "i", %num3_i64 : i64
      torch.slot "f", %num : f64
      torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
      torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
    } : !torch.nn.Module<"test">
    ```
  }];

  let arguments = (ins SymbolNameAttr:$sym_name);
  let results = (outs);
  let regions = (region SizedRegion<1>:$region);
  // Verification (e.g. of the region's contents) is implemented in C++.
  let verifier = "return ::verify(*this);";

  let assemblyFormat = "$sym_name $region attr-dict";
}
|
|
|
|
|
|
|
|
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
    HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
  let summary = "Implicit terminator for torch.class_type";

  let arguments = (ins);
  let results = (outs);

  let assemblyFormat = "attr-dict";
}
|
|
|
|
|
2021-01-28 08:35:44 +08:00
|
|
|
def Torch_MethodOp : Torch_Op<"method", [
    HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
    DeclareOpInterfaceMethods<SymbolUserOpInterface>
  ]> {
  let summary = "Declare a method of a torch.class_type";
  let description = [{
    This op declaratively specifies that the parent torch.class_type has a
    method `name` which calls `function`. `function` is an unbound function.
    That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
    "self" object).

    If `private` is present, it indicates that external calls cannot be made
    to this method.
  }];

  // We don't use sym_visibility because that only applies to Symbol's, and
  // some of the related concepts like "nested" visibility are specific to
  // symbols.
  let arguments = (ins
    StrAttr:$name,
    FlatSymbolRefAttr:$function,
    // `private` is a C++ keyword, so use `isPrivate`.
    UnitAttr:$isPrivate
  );
  let results = (outs);

  let assemblyFormat = [{
    (`private` $isPrivate^)? $name `,` $function attr-dict
  }];
}
|
|
|
|
|
2021-02-18 03:28:51 +08:00
|
|
|
def Torch_AttrOp : Torch_Op<"attr", [
    HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
  ]> {
  let summary = "Declare an attribute of a torch.class_type";
  let description = [{
    This op declaratively specifies that torch.nn.Module's of the parent
    torch.class_type must have an attribute `name` of type `type`.

    If `private` is present, it indicates that the value of this attribute
    cannot be accessed externally.
  }];

  // We don't use sym_visibility because that only applies to Symbol's, and
  // some of the related concepts like "nested" visibility are specific to
  // symbols.
  let arguments = (ins
    StrAttr:$name,
    TypeAttr:$type,
    // `private` is a C++ keyword, so use `isPrivate`
    UnitAttr:$isPrivate
  );
  let results = (outs);

  let assemblyFormat = [{
    (`private` $isPrivate^)? $name `:` $type attr-dict
  }];
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// Global slot ops
//===----------------------------------------------------------------------===//
// TODO: Should these be in a separate dialect?
// At this point, they are fairly specific to torch types, but their get/set
// semantics follow Python.
//===----------------------------------------------------------------------===//

def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
    Symbol,
    IsolatedFromAbove,
    SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
  ]> {
  let summary = "A slot with global storage";
  let description = [{
    Represents a slot with global storage. The slot semantics are the same
    as Python's: getting or setting a slot is done by object identity.

    The `typeBound` is a type that the contained type is a subtype of.
  }];

  let arguments = (ins
    SymbolNameAttr:$sym_name,
    OptionalAttr<StrAttr>:$sym_visibility,
    TypeAttr:$typeBound
  );
  let results = (outs);
  // Optional initializer region, terminated by torch.global_slot.init.
  let regions = (region SizedRegion<1>:$initializer);

  let assemblyFormat = [{
    ($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
  }];
}
|
2021-02-18 03:28:51 +08:00
|
|
|
|
2021-02-26 07:54:51 +08:00
|
|
|
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
    Terminator,
    HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
  let summary = "yield-like terminator for torch.global_slot initializer region";
  let description = [{
    The operand to this op becomes the initial value of the parent
    torch.global_slot.
  }];

  let arguments = (ins AnyTorchType:$initialValue);
  let results = (outs);

  // This builder creates an illegal op, but is needed to appease
  // ensureTerminator in the default builders for SingleBlockImplicitTerminator
  // on the parent torch.global_slot op.
  // TODO: Have a SingleBlockExplicitTerminator trait.
  let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];

  let assemblyFormat = "$initialValue attr-dict `:` type($initialValue)";
}
|
|
|
|
|
|
|
|
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
  let summary = "Get the value stored in a torch.global_slot";

  let arguments = (ins
    FlatSymbolRefAttr:$slot
  );
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    $slot attr-dict `:` type($result)
  }];
}
|
|
|
|
|
|
|
|
def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
  let summary = "Set the value stored in a torch.global_slot";

  let arguments = (ins
    FlatSymbolRefAttr:$slot,
    AnyTorchType:$value
  );
  let results = (outs);

  let assemblyFormat = [{
    $slot `=` $value attr-dict `:` type($value)
  }];
}
|
|
|
|
|
2021-02-02 09:59:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
// TorchScript `prim::` ops.
//===----------------------------------------------------------------------===//

def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
  let summary = "TorchScript prim::GetAttr op";

  let arguments = (ins StrAttr:$name, Torch_NnModuleType:$receiver);
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    $receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
  }];
}
|
|
|
|
|
|
|
|
def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
  let summary = "TorchScript prim::SetAttr op";

  let arguments = (ins
    StrAttr:$name,
    Torch_NnModuleType:$receiver,
    AnyTorchType:$value
  );
  let results = (outs);

  let assemblyFormat = [{
    $receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
  }];
}
|
|
|
|
|
|
|
|
def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
  let summary = "TorchScript prim::CallMethod op";

  let arguments = (ins
    StrAttr:$name,
    Torch_NnModuleType:$receiver,
    Variadic<AnyTorchType>:$operands
  );
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    $receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
  }];
}
|
|
|
|
|
2021-03-02 05:47:50 +08:00
|
|
|
def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
  let summary = "TorchScript prim::Print op";

  let arguments = (ins Variadic<AnyTorchType>:$operands);
  let results = (outs);

  let assemblyFormat = [{
    `(` $operands `)` attr-dict `:` type($operands)
  }];
}
|
|
|
|
|
2021-03-02 07:00:32 +08:00
|
|
|
def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
    DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> {
  let summary = "TorchScript prim::Loop op";
  let description = [{
    This op (together with prim.Loop.condition) define a looping construct
    that combines `for` and `while` behavior.

    See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
  }];

  let arguments = (ins
    I64:$maxTripCount,
    Basicpy_BoolType:$initialCondition,
    Variadic<AnyTorchType>:$iterArgsInit
  );
  let results = (outs Variadic<AnyTorchType>:$results);
  let regions = (region SizedRegion<1>:$region);

  let assemblyFormat = [{
    $maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
    attr-dict `:` functional-type(operands, results)
  }];
  // Check region/operand type agreement via the branch interface.
  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
}
|
|
|
|
|
|
|
|
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
    Terminator,
    HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
  let summary = "yield-like terminator for torch.prim.Loop";
  let description = [{
    Does not correspond to any torch prim op directly (the way that they model
    blocks has a built-in notion of yield-like terminator).
  }];

  let arguments = (ins Basicpy_BoolType:$shouldContinue, Variadic<AnyTorchType>:$iterArgs);
  let results = (outs);

  let assemblyFormat = [{
    $shouldContinue `iter` `(` $iterArgs `)`
    attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
  }];
}
|
|
|
|
|
2021-02-26 08:35:29 +08:00
|
|
|
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
  let summary = "TorchScript prim::NumToTensor op";

  let arguments = (ins AnyTorchNumberType:$num);
  let results = (outs AnyTorchTensorType:$result);

  let assemblyFormat = [{
    $num attr-dict `:` type($num) `->` type($result)
  }];
}
|
|
|
|
|
2021-03-02 05:47:50 +08:00
|
|
|
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", []> {
  let summary = "TorchScript prim::RaiseException op";

  // TODO: Error messages suggest that any exception derived from BaseException
  // is allowed at the Python level, but they seem to just be strings at the
  // IR level.
  let arguments = (ins Basicpy_BytesType:$errorMsg);
  let results = (outs);

  let assemblyFormat = [{
    $errorMsg attr-dict
  }];
}
|
|
|
|
|
|
|
|
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
  let summary = "TorchScript prim::Uninitialized op";

  let arguments = (ins);
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    attr-dict `:` type($result)
  }];
}
|
|
|
|
|
2021-03-02 09:24:15 +08:00
|
|
|
// NOTE: The def name keeps the TorchScript spelling `unchecked_cast`; renaming
// would change the generated C++ class name that other code references.
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", [
    NoSideEffect
  ]> {
  let summary = "TorchScript prim::unchecked_cast op";
  let description = [{
    Refine a type to one of its subtypes.

    For example, refine a type that was only statically known to be
    Optional[T] to a T when we obtain static information that guarantees it.

    The key observation here is that Optional[T] does not have a corresponding
    runtime type (i.e. `c10::IValue` subclass). It represents a set of possible
    concrete types which for `Optional[T]` is either `None` or a concrete
    subtype of `T` (which in the simplest case is just `T`). In particular,
    at runtime there is no way to distinguish `Optional[int]` from
    `Optional[Optional[int]]`, because both are either `None` or `int`.
    This differs from C++ std::optional.

    The best documentation of this op is inspection of the code in
    `torch/csrc/jit/frontend/ir_emitter.cpp`.
  }];

  // TODO: When we model PyTorch's notion of subtyping, verify the types here.
  let arguments = (ins AnyTorchType:$operand);
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// Additional ops used to model TorchScript's Graph's / Node's.
//===----------------------------------------------------------------------===//

def Torch_DerefineOp : Torch_Op<"derefine", [
    NoSideEffect
  ]> {
  let summary = "De-refine a type";
  let description = [{
    In terms of IR structure, TorchScript allows types to vary in many
    circumstances where MLIR requires pointer-identical types. In particular,
    it is valid to pass any subtype in place of a type. For example, if an
    `Optional[int]` is required somewhere in the IR, it is legal to pass a
    value of just `int` (but not the other way around; see
    `torch.prim.unchecked_cast`). In effect, every *use* can have a different
    type.

    This op bridges that impedance mismatch. This op allows casting a value
    from one type to a type that it is a subtype of to model this behavior.
  }];

  // TODO: When we model PyTorch's notion of subtyping, verify the types here.
  let arguments = (ins AnyTorchType:$operand);
  let results = (outs AnyTorchType:$result);

  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
|
|
|
|
|
2021-03-03 06:39:48 +08:00
|
|
|
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", []> {
  let summary = "TorchScript prim::ListUnpack op";

  let arguments = (ins AnyTorchType:$operand);
  let results = (outs Variadic<AnyTorchType>:$results);

  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($results)
  }];
}
|
|
|
|
|
|
|
|
def Torch_PrimTupleUnpackOp: Torch_Op<"prim.TupleUnpack", []> {
  let summary = "TorchScript prim::TupleUnpack op";

  let arguments = (ins AnyTorchType:$operand);
  let results = (outs Variadic<AnyTorchType>:$results);

  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($results)
  }];
}
|
|
|
|
|
2021-03-03 07:26:03 +08:00
|
|
|
def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", []> {
  let summary = "TorchScript prim::TupleIndex op";
  let arguments = (ins
    AnyTorchType:$operand,
    AnyTorchNumberType:$idx);
  let results = (outs AnyTorchType:$result);
  let assemblyFormat = [{
    $operand `,` $idx attr-dict `:` type($operand) `,` type($idx) `->` type($result)
  }];
}
|
|
|
|
|
2021-03-11 08:41:18 +08:00
|
|
|
def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
  let summary = "TorchScript prim::dtype op";
  let arguments = (ins AnyTorchTensorType:$tensor);
  let results = (outs AnyTorchNumberType:$result);
  let assemblyFormat = [{
    $tensor attr-dict `:` type($tensor) `->` type($result)
  }];
}
|
|
|
|
|
#endif // TORCH_OPS