//===-- TorchOps.td - Torch dialect op definitions ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_OPS
#define TORCH_OPS
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Interfaces/Traits.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
}
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"

//===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops.
//===----------------------------------------------------------------------===//
def Torch_NnModuleOp : Torch_Op<"nn_module", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
let summary = "Constructs a torch.nn.Module";
let description = [{
This op is used to represent a torch.nn.Module when importing a
graph of Python objects.
This op returns a new torch.nn.Module as an SSA value, with a set of
declaratively specified properties.
Example:
```mlir
%2 = torch.nn_module {
  torch.slot "b", %bool_true : !basicpy.BoolType
  torch.slot "i", %num3_i64 : i64
  torch.slot "f", %num : f64
  torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
  torch.slot "submodule", %1 : !torch.nn.Module
} : !torch.nn.Module<"my_class_name">
```
This op is tightly coupled to the `torch.class_type` op named in the
`!torch.nn.Module<"my_class_name">` type. Each slot must match precisely
with the corresponding `torch.attr` in the `torch.class_type`.
See the documentation for `torch.class_type` for information.
}];
let arguments = (ins);
let results = (outs Torch_NnModuleType:$result);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let assemblyFormat = "$region attr-dict `:` type($result)";
let extraClassDeclaration = [{
StringRef getClassName() { return getType().getClassName(); }
ClassTypeOp getClassType(::mlir::SymbolTable &symbolTable) {
return symbolTable.lookup<ClassTypeOp>(getClassName());
}
}];
}
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
let summary = "Implicit terminator for torch.nn_module";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}
def Torch_SlotOp : Torch_Op<"slot", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
let summary = "Define the value of a slot of a torch.nn.Module";
let description = [{
This op specifies that the initial value of the slot `name` of the
parent torch.nn_module should be `value`, which is allowed to be an
arbitrary Torch-compatible SSA value, including other !torch.nn.Module's.
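Example (a minimal sketch, reusing the names from the `torch.nn_module`
example above):
```mlir
%m = torch.nn_module {
  torch.slot "f", %num : f64
} : !torch.nn.Module<"my_class_name">
```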
}];
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
let results = (outs);
let assemblyFormat = [{
$name `,` $value attr-dict `:` type($value)
}];
}

//===----------------------------------------------------------------------===//
// Modeling of TorchScript class types
//===----------------------------------------------------------------------===//
def Torch_ClassTypeOp : Torch_Op<"class_type", [
Symbol,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
let summary = "Constructs a torch.ClassType";
let description = [{
Declares a class type. Class types are the types used to describe
TorchScript `torch.nn.Module`'s. The terminology "class type" is for
consistency with TorchScript (a better name in our context might be
"nn module subtype"). The `syn_name` of this op is the same string
as in the `!torch.nn.Module<"...">` type.
Example:
```mlir
// A simple empty torch.class_type, with corresponding torch.nn_module.
torch.class_type @empty {}
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">

// A class type with many members.
torch.class_type @test {
  torch.attr "b" : !basicpy.BoolType
  torch.attr "i" : i64
  torch.attr "f" : f64
  torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
  torch.attr "submodule" : !torch.nn.Module<"empty">
  torch.method "method", @f
}
%module = torch.nn_module {
  // These must match the order and names in the `torch.class_type`.
  torch.slot "b", %bool_true : !basicpy.BoolType
  torch.slot "i", %num3_i64 : i64
  torch.slot "f", %num : f64
  torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
  torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
} : !torch.nn.Module<"test">
```
}];
let arguments = (ins SymbolNameAttr:$sym_name);
let results = (outs);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let assemblyFormat = "$sym_name $region attr-dict";
}
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
let summary = "Implicit terminator for torch.class_type";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}
def Torch_MethodOp : Torch_Op<"method", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Declare a method of a torch.class_type";
let description = [{
This op declaratively specifies that the parent torch.class_type has a
method `name` which calls `function`. `function` must be an unbound
function; that is, it explicitly takes the torch.nn.Module as its first
parameter (there is no implicit "self" object).
If `private` is present, it indicates that external calls cannot be made
to this method.
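Example (a sketch; the class, method, and function names are hypothetical,
and the era's builtin `func`/`return` syntax is assumed):
```mlir
torch.class_type @c {
  torch.method "forward", @forward
}
// The referenced function is unbound: it takes the module explicitly.
func @forward(%self: !torch.nn.Module<"c">) {
  return
}
```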
}];
// We don't use sym_visibility because that only applies to Symbol's, and
// some of the related concepts like "nested" visibility are specific to
// symbols.
let arguments = (ins
StrAttr:$name,
FlatSymbolRefAttr:$function,
// `private` is a C++ keyword, so use `isPrivate`.
UnitAttr:$isPrivate
);
let results = (outs);
let assemblyFormat = [{
(`private` $isPrivate^)? $name `,` $function attr-dict
}];
}
def Torch_AttrOp : Torch_Op<"attr", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
]> {
let summary = "Declare an attribute of a torch.class_type";
let description = [{
This op declaratively specifies that torch.nn.Module's of the parent
torch.class_type must have an attribute `name` of type `type`.
If `private` is present, it indicates that the value of this attribute
cannot be accessed externally.
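Example (a sketch; the attribute names and types are illustrative):
```mlir
torch.class_type @c {
  torch.attr "training" : !basicpy.BoolType
  torch.attr private "version" : i64
}
```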
}];
// We don't use sym_visibility because that only applies to Symbol's, and
// some of the related concepts like "nested" visibility are specific to
// symbols.
let arguments = (ins
StrAttr:$name,
TypeAttr:$type,
// `private` is a C++ keyword, so use `isPrivate`
UnitAttr:$isPrivate
);
let results = (outs);
let assemblyFormat = [{
(`private` $isPrivate^)? $name `:` $type attr-dict
}];
}

//===----------------------------------------------------------------------===//
// Global slot ops
//===----------------------------------------------------------------------===//
// TODO: Should these be in a separate dialect?
// At this point, they are fairly specific to torch types, but their get/set
// semantics follow Python.
//===----------------------------------------------------------------------===//
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
Symbol,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
]> {
let summary = "A slot with global storage";
let description = [{
Represents a slot with global storage. The slot semantics are the same
as Python's: getting or setting a slot is done by object identity.
The `typeBound` is a type that the type of the stored value is guaranteed
to be a subtype of.
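Example (a sketch; the slot name is hypothetical, and `std.constant` is
assumed for the initial value):
```mlir
torch.global_slot "private" @counter : i64 {
  %c0_i64 = constant 0 : i64
  torch.global_slot.init %c0_i64 : i64
}

// Reads and writes go through the symbol:
%0 = torch.global_slot.get @counter : i64
torch.global_slot.set @counter = %0 : i64
```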
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
OptionalAttr<StrAttr>:$sym_visibility,
TypeAttr:$typeBound
);
let results = (outs);
let regions = (region SizedRegion<1>:$initializer);
let assemblyFormat = [{
($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
}];
}
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator,
HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
let summary = "yield-like terminator for torch.global_slot initializer region";
let description = [{
The operand to this op becomes the initial value of the parent
torch.global_slot.
}];
let arguments = (ins AnyTorchType:$initialValue);
let results = (outs);
// This builder creates an illegal op, but is needed to appease
// ensureTerminator in the default builders for SingleBlockImplicitTerminator
// on the parent torch.global_slot op.
// TODO: Have a SingleBlockExplicitTerminator trait.
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
let assemblyFormat = "$initialValue attr-dict `:` type($initialValue)";
}
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
let summary = "Get the value stored in a torch.global_slot";
let arguments = (ins
FlatSymbolRefAttr:$slot
);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$slot attr-dict `:` type($result)
}];
}
def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
let summary = "Set the value stored in a torch.global_slot";
let arguments = (ins
FlatSymbolRefAttr:$slot,
AnyTorchType:$value
);
let results = (outs);
let assemblyFormat = [{
$slot `=` $value attr-dict `:` type($value)
}];
}

//===----------------------------------------------------------------------===//
// TorchScript interpreter builtin ops.
//===----------------------------------------------------------------------===//
// These don't correspond to a `torch::jit::Operator`, so they don't appear
// in the registry and cannot be autogenerated.
// Most of these correspond 1:1 to interpreter opcodes, though some concepts
// are modeled differently (e.g. the interpreter lowers control flow to raw
// branches, whereas these ops keep it structured).
// See `torch/csrc/jit/runtime/instruction.h`.
def Torch_PrimListUnpackOp : Torch_Op<"prim.ListUnpack",
[AllowsTypeRefinement]> {
let summary = "TorchScript prim::ListUnpack op";
let arguments = (ins AnyTorchType:$operand);
let results = (outs Variadic<AnyTorchType>:$results);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($results)
}];
}
def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
let summary = "TorchScript prim::GetAttr op";
let arguments = (ins StrAttr:$name, Torch_NnModuleType:$receiver);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
}];
}
def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
let summary = "TorchScript prim::SetAttr op";
let arguments = (ins
StrAttr:$name,
Torch_NnModuleType:$receiver,
AnyTorchType:$value
);
let results = (outs);
let assemblyFormat = [{
$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
}];
}
def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
let summary = "TorchScript prim::CallMethod op";
let arguments = (ins
StrAttr:$name,
Torch_NnModuleType:$receiver,
Variadic<AnyTorchType>:$operands
);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
}];
}
def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> {
let summary = "TorchScript prim::Loop op";
let description = [{
This op (together with prim.Loop.condition) defines a looping construct
that combines `for` and `while` behavior.
See: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
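Example (a sketch of a loop summing the iteration indices; assumes
`std.addi` for the accumulation and that, per the TorchScript loop
convention, the first block argument is the iteration index):
```mlir
%sum = torch.prim.Loop %n, %true, init(%c0_i64) {
^bb0(%iv: i64, %acc: i64):
  %next = addi %acc, %iv : i64
  // Keep looping (up to %n times), carrying %next into the next iteration.
  torch.prim.Loop.condition %true, iter(%next : i64)
} : (i64, !basicpy.BoolType, i64) -> i64
```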
}];
let arguments = (ins
I64:$maxTripCount,
Basicpy_BoolType:$initialCondition,
Variadic<AnyTorchType>:$iterArgsInit
);
let results = (outs Variadic<AnyTorchType>:$results);
let regions = (region SizedRegion<1>:$region);
let assemblyFormat = [{
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
attr-dict `:` functional-type(operands, results)
}];
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
}
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop";
let description = [{
Does not correspond to any torch prim op directly (the way TorchScript
models blocks has a built-in notion of a yield-like terminator).
}];
let arguments = (ins
Basicpy_BoolType:$shouldContinue,
Variadic<AnyTorchType>:$iterArgs
);
let results = (outs);
let assemblyFormat = [{
$shouldContinue `,`
`iter` `(` ($iterArgs^ `:` type($iterArgs))? `)` attr-dict
}];
}

//===----------------------------------------------------------------------===//
// Additional ops used to model TorchScript's Graph's / Node's.
//===----------------------------------------------------------------------===//
def Torch_DerefineOp : Torch_Op<"derefine", [
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>
]> {
let summary = "De-refine a type";
let description = [{
In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.
This op bridges that impedance mismatch: it casts a value from its type
to any type that it is a subtype of (i.e. to a less refined type).
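Example (a sketch; assumes the `!torch.optional<T>` type syntax):
```mlir
// Use an `i64` value where an `Optional[int]` is expected.
%opt = torch.derefine %int : i64 to !torch.optional<i64>
```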
}];
let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasCanonicalizer = 1;
}
def Torch_OperatorOp : Torch_Op<"operator", [
AllowsTypeRefinement
]> {
let summary = "Opaque torch operator";
let description = [{
Represents an invocation of a `torch::jit::Operator` for which we don't
have a registered MLIR operation.
The `name` attribute contains the name that the MLIR op would have
(excluding `torch.`) if we did have it registered, which allows easy
cross-referencing with `JITOperatorRegistryDump.txt`.
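Example (a sketch; the operator name is hypothetical):
```mlir
%0 = torch.operator "aten.frobnicate"(%arg)
    : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
```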
}];
let arguments = (ins StrAttr:$name, Variadic<AnyTorchType>:$operands);
let results = (outs Variadic<AnyTorchType>:$results);
let assemblyFormat = [{
$name `(` $operands `)` attr-dict `:` functional-type($operands, $results)
}];
}
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
AllowsTypeRefinement
]> {
let summary = "Create a `!torch.LinearParams`";
let arguments = (ins
AnyTorchTensorType:$weight,
Optional<AnyTorchTensorType>:$bias
);
let results = (outs Torch_LinearParamsType:$result);
let assemblyFormat = [{
$weight (`,` $bias^)? attr-dict `:` type($weight) (`,` type($bias)^)?
}];
}
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
AllowsTypeRefinement
]> {
let summary = "Create a per-tensor-affine quantized tensor";
let description = [{
Create a quantized tensor.
The quantization formula is:
```
Q(x, scale, zero_point) = round(x/scale + zero_point)
```
See:
https://pytorch.org/docs/stable/quantization.html#quantized-tensors
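For example, with `scale = 0.1` and `zero_point = 5`, the value `x = 1.0`
quantizes to `round(1.0/0.1 + 5) = 15`.
Example IR (a sketch; SSA names are hypothetical, and the result is shown
with an unrefined tensor type pending the quantized-dtype TODO below):
```mlir
%q = torch.per_tensor_affine.create %int_repr, %scale, %offset
    : !numpy.ndarray<*:!numpy.any_dtype>, f64, i64 -> !numpy.ndarray<*:!numpy.any_dtype>
```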
}];
let arguments = (ins
AnyTorchTensorType:$int_repr,
AnyFloat:$scale,
AnyTorchIntType:$offset
);
// TODO: Limit to quantized dtypes (e.g. !torch.qint8).
let results = (outs AnyTorchTensorType:$result);
let assemblyFormat = [{
$int_repr `,` $scale `,` $offset attr-dict
`:` type($int_repr) `,` type($scale) `,` type($offset) `->` type($result)
}];
}
#endif // TORCH_OPS