Add a torch.kernel_call op and associated predicates.

pull/64/head
Stella Laurenzo 2020-09-29 14:17:34 -07:00 committed by Stella Laurenzo
parent ba03ecc652
commit 3d74337be0
6 changed files with 97 additions and 9 deletions

View File

@@ -10,6 +10,8 @@
#define TORCH_BASE
include "mlir/IR/OpBase.td"
include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.td"
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
def Torch_Dialect : Dialect {
let name = "torch";
@@ -38,4 +40,60 @@ def Torch_Dialect : Dialect {
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
// Torch has a fairly advanced and featureful Tensor type, and some of the
// semantics are important to preserve in a compilation context. In the future,
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
// and interop are well served by existing tensor-like types, which are
// specifically permitted. Typically, on import, constraints are fairly loose
// and based on how the program is captured. Settling on and refining to
// specific types is done as part of lowering.
//
// While lowering it is useful to be able to distinguish between mutable and
// immutable tensors:
// - Mutable tensors can alias.
// - Mutable tensors can be a view over another mutable tensor.
// - Mutable tensors act as if reference counted and exist for the lifetime
// of any reference or derived view.
// Conversely, immutable tensors:
// - Are normal SSA values representing the contents of the tensor.
// - Cannot alias.
// - Cannot be a view of any mutable value.
// - Have undefined lifetimes.
//
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
// however, when lowering to specific ops, further constraints are introduced,
// necessitating copies, loads, and stores to be inserted to bridge worlds.
// Immutable tensors: normal SSA values that cannot alias and cannot be a
// view over a mutable value (see the commentary above).
def AnyTorchImmutableTensor : AnyTypeOf<[
// Normal MLIR immutable tensors.
AnyTensor,
], "allowable torch immutable tensor">;
// Mutable tensors: reference-semantic values that may alias and may be a
// view over another mutable tensor (see the commentary above).
def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the
// same aliasing as Torch tensors.
Numpy_NdArrayType,
], "allowable torch mutable tensor">;
// Union of the immutable and mutable tensor predicates: the loose constraint
// used at the Torch dialect level before types are refined during lowering.
def AnyTorchTensorType : AnyTypeOf<[
AnyTorchImmutableTensor,
AnyTorchMutableTensor,
], "Any tensor type legal to pass to a Torch kernel">;
// Primitive (non-tensor) types that may be passed as a Torch "Scalar".
// NOTE(review): floating point and unsigned integer types are not listed —
// confirm whether that is intentional or pending later support.
def AnyScalar : AnyTypeOf<[
AnySignedInteger,
Basicpy_BoolType,
Basicpy_StrType,
Basicpy_NoneType,
], "Any primitive type suitable to be passed as a Torch Scalar">;
// Any type legal as a Torch kernel argument or result: a scalar or a tensor.
def AnyTorchType : AnyTypeOf<[
AnyScalar,
AnyTorchTensorType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_BASE

View File

@@ -15,8 +15,28 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
}
def Torch_DummyOp : Torch_Op<"dummy"> {
let summary = "Dummy placeholder op until more is defined";
// torch.kernel_call: models an invocation of a Torch custom kernel. The
// runtime matches the call by signature — the fully qualified kernel name
// plus the tuple of argument types (see the description below).
def Torch_KernelCall : Torch_Op<"kernel_call"> {
let summary = "Calls a Torch custom kernel";
let description = [{
Torch kernel calls are matched by the runtime based on signature, including
the fully qualified kernel name (i.e. "namespace::name") and the tuple of
argument types. This op models such an invocation.
Caveat: When interacting with the PyTorch boxed interpreter stack, the
arguments follow a convention that the top of stack is the first argument.
}];
// $kernel_name is the fully qualified "namespace::name" string; $args may be
// any mix of Torch scalar and tensor types (AnyTorchType above).
let arguments = (ins
StrAttr:$kernel_name,
Variadic<AnyTorchType>:$args
);
// Results are likewise unconstrained beyond being legal Torch types.
let results = (outs
Variadic<AnyTorchType>:$results
);
// Printed form, e.g.:
//   torch.kernel_call "somens::someop" %0, %1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
let assemblyFormat = [{
$kernel_name $args attr-dict `:` functional-type($args, results)
}];
}
#endif // TORCH_OPS

View File

@@ -14,4 +14,6 @@ add_mlir_dialect_library(NPCOMPTorchDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
NPCOMPBasicpyDialect
NPCOMPNumpyDialect
)

View File

@@ -8,8 +8,16 @@
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "mlir/IR/Builders.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Basicpy;
using namespace mlir::NPCOMP::Numpy;
using namespace mlir::NPCOMP::Torch;
#define GET_OP_CLASSES

View File

@@ -1,7 +0,0 @@
// RUN: npcomp-opt %s | FileCheck %s
// Round-trip parse/print test for the placeholder torch.dummy op
// (this file is removed by the commit in favor of torch-ops.mlir below).
func @dummy() {
// CHECK: "torch.dummy"
"torch.dummy"() : () -> ()
return
}

View File

@@ -0,0 +1,7 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
// Round-trip parse/print test for torch.kernel_call with a mixed
// scalar/tensor argument list and an unranked tensor result.
func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
// CHECK: %0 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
%1 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> (tensor<*xf32>)
return %1 : tensor<*xf32>
}