Add remaining pieces to capture full example models.

* Adds Basicpy List, Tuple, and Dict types and plumbs them through the C API (see the usage sketch below the commit metadata).
* Started debugging the issues around aten::conv2d capture; a PyTorch bug is suspected.
* Was able to manually verify that the basic conv2d forward test captures correctly with a workaround.
* Need to resolve some printing issues upstream and move these tests to an integration test target (they take on the order of seconds to run).
pull/88/head
Stella Laurenzo 2020-10-16 17:38:07 -07:00
parent 81119aa0a1
commit 029815152e
15 changed files with 404 additions and 102 deletions
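A quick usage sketch of the new C API surface mentioned above, assuming a context that already has the Basicpy dialect loaded (the include paths and the registration step are assumptions, not part of this change):

#include "mlir-c/IR.h"      // core MLIR C API (path assumed)
#include "npcomp-c/Types.h" // npcomp type C API (path assumed)
#include <assert.h>

// Create and query the new Basicpy container types through the C API.
// Assumes `context` already has the Basicpy dialect loaded.
static void exerciseBasicpyTypes(MlirContext context) {
  MlirType listType = npcompListTypeGet(context);
  MlirType tupleType = npcompTupleTypeGet(context);
  MlirType dictType = npcompDictTypeGet(context);
  MlirType noneType = npcompNoneTypeGet(context);

  // The predicates mirror the getters added in this commit.
  assert(npcompTypeIsAList(listType));
  assert(npcompTypeIsATuple(tupleType));
  assert(npcompTypeIsADict(dictType));
  assert(npcompTypeIsANone(noneType));
}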

View File

@ -31,7 +31,11 @@ using c10::Stack;
// that the TORCH_LIBRARY_* macros expand this by name and other APIs use its
// enum value, so we define both. We can get rid of both once we have our
// own key.
#define ACAP_DISPATCH_KEY PrivateUse1
// TODO: Ask the PT devs why conv is special and only shows up if dispatching
// through the autograd keys.
// https://github.com/llvm/mlir-npcomp/issues/86
// #define ACAP_DISPATCH_KEY AutogradPrivateUse3
#define ACAP_DISPATCH_KEY PrivateUse3
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
std::list<AcapController::Activation> &
@ -74,8 +78,8 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
// Exclude recursive dispatch in order to print tensor.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
std::stringstream msg;
msg << "Cannot return a tensor that is not from the capture context: ";
msg << tensor;
msg << "Cannot return a tensor that is not from the capture context: "
<< tensor;
throw std::invalid_argument(msg.str());
}
@ -85,8 +89,7 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
MlirLocation loc = getCurrentLocation();
OperationStateHolder s("std.return", loc);
mlirOperationStateAddOperands(&s.state, returnsValues.size(),
returnsValues.data());
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(
s.createOperation());
funcBuilder->rewriteFuncReturnTypes(returnsTypes);
@ -161,7 +164,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
"kernel_name",
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
mlirOperationStateAddAttributes(&stateHolder.state, 1, &kernelNameAttr);
mlirOperationStateAddAttributes(stateHolder, 1, &kernelNameAttr);
// Map arguments to operands.
// This must be accumulated into the OperationState prior to re-dispatch
@ -179,8 +182,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
}
operands.push_back(mlirValue);
}
mlirOperationStateAddOperands(&stateHolder.state, operands.size(),
operands.data());
mlirOperationStateAddOperands(stateHolder, operands.size(), operands.data());
// Invoke the original kernel.
redispatch(opHandle, stack);
@ -205,7 +207,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
resultIndexToTensorMap.emplace_back(resultIndex, returnIt->toTensor());
}
}
mlirOperationStateAddResults(&stateHolder.state, resultTypes.size(),
mlirOperationStateAddResults(stateHolder, resultTypes.size(),
resultTypes.data());
// Create operation.
@ -244,9 +246,19 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return funcBuilder->getBoolConstant(loc, ival.toBool());
}
if (ival.isList()) {
auto list = ival.toList();
llvm::SmallVector<MlirValue, 4> elements;
for (c10::IValue element : list) {
elements.push_back(mapIValueToMlirValue(loc, element));
}
return funcBuilder->buildConstantList(loc, elements);
}
if (ival.isNone()) {
return funcBuilder->getNoneConstant(loc);
}
return {nullptr};
// TODO: Implement mappings for the whole set (relevant to this use case):
// _(None)
// _(Tensor)
// _(Double)
// _(Int)
@ -277,6 +289,12 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
if (ival.isList()) {
return npcompListTypeGet(funcBuilder->getContext());
}
if (ival.isNone()) {
return npcompNoneTypeGet(funcBuilder->getContext());
}
return {nullptr};
}
@ -357,3 +375,8 @@ TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
}
TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
}
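The mapIValueToMlirValue additions above handle bool, list, and None values. The remaining scalar kinds in the TODO list could follow the same pattern; a hedged sketch for Double (illustrative only, not part of this commit, and assuming the upstream MLIR C API float attribute helpers):

// Illustrative only: map a scalar Double IValue by materializing an f64
// constant through FuncBuilder::getGeneralConstant, mirroring the bool case.
if (ival.isDouble()) {
  MlirContext context = funcBuilder->getContext();
  MlirType f64Type = mlirF64TypeGet(context);
  MlirAttribute valueAttr =
      mlirFloatAttrDoubleGet(context, f64Type, ival.toDouble());
  return funcBuilder->getGeneralConstant(loc, valueAttr);
}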

View File

@ -17,8 +17,8 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
MlirAttribute value) {
OperationStateHolder s("std.constant", loc);
MlirNamedAttribute valueAttr = mlirNamedAttributeGet("value", value);
mlirOperationStateAddResults(&s.state, 1, &type);
mlirOperationStateAddAttributes(&s.state, 1, &valueAttr);
mlirOperationStateAddResults(s, 1, &type);
mlirOperationStateAddAttributes(s, 1, &valueAttr);
return s.createOperation();
}
@ -170,6 +170,14 @@ MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
return getGeneralConstant(loc, value);
}
MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
OperationStateHolder state{"basicpy.singleton", loc};
MlirType noneType = npcompNoneTypeGet(context);
mlirOperationStateAddResults(state, 1, &noneType);
MlirOperation op = state.createOperation();
return insertConstantOp(op);
}
MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
MlirAttribute value) {
MlirType valueType = mlirAttributeGetType(value);
@ -177,3 +185,14 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
MlirValue constValue = insertConstantOp(constOp);
return constValue;
}
MlirValue
FuncBuilder::buildConstantList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc};
mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data());
MlirOperation op = state.createOperation();
return insertConstantOp(op);
}

View File

@ -31,15 +31,16 @@ public:
}
}
operator MlirOperationState *() { return &state; }
MlirOperation createOperation() {
assert(owned && "cannot createOperation on unowned state");
owned = false;
return mlirOperationCreate(&state);
}
MlirOperationState state;
private:
MlirOperationState state;
bool owned = true;
};
@ -123,10 +124,19 @@ public:
/// Gets a bool constant value.
MlirValue getBoolConstant(MlirLocation loc, bool v);
/// Gets a None constant value.
MlirValue getNoneConstant(MlirLocation loc);
/// Gets a general constant value representing the given value
/// attribute.
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
/// Builds a list with the given elements (derived from constants).
/// The resulting list is inserted into the "constant section" of the
/// function.
MlirValue buildConstantList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements);
private:
FuncBuilder(MlirContext context, MlirOperation funcOp,
BlockBuilder entryBlock)
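FuncBuilder now carries helpers for the Python container constants that captured kernels need (for example the stride/padding/dilation lists of aten::conv2d). A minimal sketch of how a caller such as AcapController could use them, assuming `funcBuilder`, `loc`, and `context` are in scope as they are in the capture path (illustrative only):

// Materialize a Python list constant [1, 1] (e.g. a conv2d stride) in the
// function's constant section, then a None constant for an omitted optional.
MlirType i64Type = mlirIntegerTypeGet(context, 64);
llvm::SmallVector<MlirValue, 4> elements;
elements.push_back(
    funcBuilder->getGeneralConstant(loc, mlirIntegerAttrGet(i64Type, 1)));
elements.push_back(
    funcBuilder->getGeneralConstant(loc, mlirIntegerAttrGet(i64Type, 1)));
MlirValue strideList = funcBuilder->buildConstantList(loc, elements);
MlirValue none = funcBuilder->getNoneConstant(loc);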

View File

@ -11,8 +11,8 @@ import torch.nn.functional as F
import torch_mlir
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/79
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
class ResA(nn.Module):
@ -54,4 +54,6 @@ with mb.capture_function("resa", [inputs]) as f:
# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
print(mb.module)
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)

View File

@ -1,50 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
# XFAIL: *
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model.weight.data = model.weight.clone()
ref_model.bias.data = model.bias.clone()
softmax = torch.nn.LogSoftmax(dim=1)
loss = torch.nn.NLLLoss()
tensor = torch.randn(N, Cin, h, w)
with mb.capture_function("@conv2d_fwd", [tensor]) as f:
result = model(tensor)
f.returns([result])
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
ref_target = target.clone()
with mb.capture_function("@conv2d_backward", [result, target]) as f:
test_loss = loss(softmax(result), target)
f.returns([test_loss.backward()])
# CHECK-LABEL: func @conv2d_fwd
# TODO: Add checks when passing
# CHECK-LABEL: func @conv2d_backward
# TODO: Update checks when passing
# NO-CHECK: aten.convolution_overrideable
# NO-CHECK: aten._log_softmax
# NO-CHECK: aten.nll_loss2d_forward
print(mb.module)

View File

@ -0,0 +1,41 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model.weight.data = model.weight.clone()
ref_model.bias.data = model.bias.clone()
softmax = torch.nn.LogSoftmax(dim=1)
loss = torch.nn.NLLLoss()
tensor = torch.randn(N, Cin, h, w)
with mb.capture_function("conv2d_fwd", [tensor]) as f:
result = model(tensor)
f.returns([result])
# CHECK-LABEL: func @conv2d_fwd
# CHECK-SAME: (%arg0: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
# CHECK: %[[P1:.*]] = numpy.create_array_from_tensor %cst : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
# CHECK: %[[P2:.*]] = numpy.create_array_from_tensor %cst_0 : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
# CHECK: %[[R:.*]] = torch.kernel_call "aten::conv2d" %arg0, %[[P1]], %[[P2]], %0, %1, %2, %c1_i64_5 : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
# CHECK: return %[[R]] : !numpy.ndarray<[3,4,8,8]:f32>
print(mb.module)

View File

@ -6,8 +6,9 @@ import torch
import torch_mlir
import torchvision.models as models
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# TODO: Pass through npcomp-opt and FileCheck once able to elide large elements.
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
model = models.resnet18()
@ -21,8 +22,9 @@ with mb.capture_function("res18", [tensor]) as f:
result = model(tensor)
f.returns([result])
print(mb.module)
# for now we just check the output shape
# CHECK-LABEL: @res18
# TODO: Add checks once running to this point.
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)

View File

@ -6,8 +6,8 @@ import torch
import torch_mlir
import torchvision.models as models
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
model = models.vgg11_bn()
@ -23,4 +23,6 @@ with mb.capture_function("vgg11", [inputs]) as f:
# CHECK-LABEL: func @vgg11
# TODO: Add checks once passing this far.
print(mb.module)
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)

View File

@ -16,6 +16,17 @@
extern "C" {
#endif
/*============================================================================*/
/* Any dtype type. */
/*============================================================================*/
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
int npcompTypeIsAAnyDtype(MlirType t);
/** Gets the "any dtype" type. */
MlirType npcompAnyDtypeTypeGet(MlirContext context);
/*============================================================================*/
/* Bool type. */
/*============================================================================*/
@ -27,15 +38,24 @@ int npcompTypeIsABool(MlirType t);
MlirType npcompBoolTypeGet(MlirContext context);
/*============================================================================*/
/* Any dtype type. */
/* Dict type. */
/*============================================================================*/
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
int npcompTypeIsAAnyDtype(MlirType t);
/** Checks whether the given type is the Python "dict" type. */
int npcompTypeIsADict(MlirType t);
/** Gets the "any dtype" type. */
MlirType npcompAnyDtypeTypeGet(MlirContext context);
/** Gets the generic Python "dict" type. */
MlirType npcompDictTypeGet(MlirContext context);
/*============================================================================*/
/* List type. */
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsAList(MlirType t);
/** Gets the generic Python "list" type. */
MlirType npcompListTypeGet(MlirContext context);
/*============================================================================*/
/* NDArray type. */
@ -52,6 +72,27 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
/*============================================================================*/
/* None type. */
/*============================================================================*/
/** Checks whether the given type is the type of the singleton 'None' value. */
int npcompTypeIsANone(MlirType t);
/** Gets the type of the singleton 'None'. */
MlirType npcompNoneTypeGet(MlirContext context);
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/
/** Checks whether the given type is the Python "tuple" type. */
int npcompTypeIsATuple(MlirType t);
/** Gets the generic Python "tuple" type. */
MlirType npcompTupleTypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif

View File

@ -37,6 +37,13 @@ public:
static BytesType get(MLIRContext *context) { return Base::get(context); }
};
/// Python 'dict' type.
class DictType : public Type::TypeBase<DictType, Type, TypeStorage> {
public:
using Base::Base;
static DictType get(MLIRContext *context) { return Base::get(context); }
};
/// The type of the Python `Ellipsis` value.
class EllipsisType : public Type::TypeBase<EllipsisType, Type, TypeStorage> {
public:
@ -44,6 +51,13 @@ public:
static EllipsisType get(MLIRContext *context) { return Base::get(context); }
};
/// Python 'list' type.
class ListType : public Type::TypeBase<ListType, Type, TypeStorage> {
public:
using Base::Base;
static ListType get(MLIRContext *context) { return Base::get(context); }
};
/// The type of the Python `None` value.
class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
public:
@ -74,6 +88,13 @@ public:
static StrType get(MLIRContext *context) { return Base::get(context); }
};
/// Python 'tuple' type.
class TupleType : public Type::TypeBase<TupleType, Type, TypeStorage> {
public:
using Base::Base;
static TupleType get(MLIRContext *context) { return Base::get(context); }
};
/// An unknown type that could be any supported python type.
class UnknownType : public Type::TypeBase<UnknownType, Type, TypeStorage,
NPCOMPTypingTypeMapInterface::Trait> {

View File

@ -19,7 +19,7 @@ def Basicpy_Dialect : Dialect {
let name = "basicpy";
let summary = "Basic Python dialect";
let description = [{
Core types and ops
Core types and ops
}];
let cppNamespace = "::mlir::NPCOMP::Basicpy";
}
@ -31,7 +31,7 @@ def Basicpy_Dialect : Dialect {
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Basicpy_Dialect, mnemonic, traits> {
let parser = [{ return parse$cppClass(parser, &result); }];
let printer = [{ return print$cppClass(p, *this); }];
let printer = [{ return print$cppClass(p, *this); }];
}
//===----------------------------------------------------------------------===//
@ -71,7 +71,7 @@ def Basicpy_NoneType : DialectType<Basicpy_Dialect,
}
def Basicpy_SlotObjectType : DialectType<Basicpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">,
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">,
"Slot object"> {
let typeDescription = [{
Type for built-in objects which have a fixed number of slots and a type
@ -96,6 +96,47 @@ def Basicpy_UnknownType : DialectType<Basicpy_Dialect,
}];
}
def Basicpy_ListType : DialectType<Basicpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::ListType>()">,
"List type"> {
let typeDescription = [{
A Python list type. In the non-parameterized case, there are limited
constraints on the element type or length; however, it can be refined to
include such constraints.
As in Python, this list type represents a mutable, reference-counted
object in a corresponding runtime layer.
}];
}
def Basicpy_TupleType : DialectType<Basicpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::TupleType>()">,
"Tuple type"> {
let typeDescription = [{
A Python tuple type. In the non-parameterized case, there are limited
constraints on the element type or length; however, it can be refined to
include such constraints.
As in Python, post-construction tuples are immutable, reference-counted
objects in a corresponding runtime layer. However, since they are
immutable, they can also serve as value-typed entities if their elements
are immutable.
}];
}
def Basicpy_DictType : DialectType<Basicpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::DictType>()">,
"Dict type"> {
let typeDescription = [{
A Python dict type. In the non-parameterized case, there are limited
constraints on the key or value types; however, it can be refined to
include such constraints.
As in Python, this dict type represents a mutable, reference-counted
object in a corresponding runtime layer.
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
@ -117,7 +158,7 @@ class Basicpy_SlotObjectOfClassArity<string className, int arity> :
// Type representing a 'slice' object, which mirrors the Python built-in
// slice class.
def Basicpy_SliceSlotObjectType :
def Basicpy_SliceSlotObjectType :
Type<Basicpy_SlotObjectOfClassArity<"slice", 3>>;
#endif // NPCOMP_DIALECT_BASICPY_IR_BASICPY_DIALECT

View File

@ -90,7 +90,7 @@ def CompareOperationAttr : StrEnumAttr<
}
//===----------------------------------------------------------------------===//
// Ops
// Operations
//===----------------------------------------------------------------------===//
def Basicpy_BinaryCompareOp : Basicpy_Op<"binary_compare", []> {
@ -154,6 +154,66 @@ def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
let hasFolder = 1;
}
// TODO: Implement ConstantLike op trait.
def Basicpy_BuildDictOp : Basicpy_Op<"build_dict", [NoSideEffect]> {
let summary = "Builds an empty dict";
let description = [{
This op mirrors the CPython BUILD_MAP op (note naming difference).
Note that, as with CPython, this op only builds an empty dict; however, it
is reserved to take variadic operands in the future so that a dict can be
constructed from a list of key/value pairs.
}];
let arguments = (ins
);
let results = (outs
Basicpy_DictType:$result
);
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
}
// TODO: Implement ConstantLike op trait.
def Basicpy_BuildListOp : Basicpy_Op<"build_list", [NoSideEffect]> {
let summary = "Builds a list from operands";
let description = [{
Constructs a new list object from its operands.
TODO: Any allowable type can be expressed in lists; however, this should be
revisited once more of the dialect infrastructure is in place and tightened
up accordingly. At that time, appropriate constraints should be added that
both allow correct program representation and support transformations to
lower levels (i.e. allowing a wider set of types as useful for conversions).
}];
let arguments = (ins
Variadic<AnyType>:$elements
);
let results = (outs
Basicpy_ListType:$result
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
// TODO: Implement ConstantLike op trait.
def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
let summary = "Builds a tuple from operands";
let description = [{
Constructs a new tuple object from its operands.
TODO: Any allowable type can be expressed in tuples; however, this should be
revisited once more of the dialect infrastructure is in place and tightened
up accordingly. At that time, appropriate constraints should be added that
both allow correct program representation and support transformations to
lower levels (i.e. allowing a wider set of types as useful for conversions).
}];
let arguments = (ins
Variadic<AnyType>:$elements
);
let results = (outs
Basicpy_TupleType:$result
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
ConstantLike, NoSideEffect]> {
let summary = "Constant bytes value";

View File

@ -14,42 +14,96 @@
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP::Basicpy;
using namespace mlir::NPCOMP::Numpy;
/*============================================================================*/
/* Bool type. */
/*============================================================================*/
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<BoolType>(); }
MlirType npcompBoolTypeGet(MlirContext context) {
return wrap(BoolType::get(unwrap(context)));
}
using namespace mlir::NPCOMP;
/*============================================================================*/
/* Any dtype type. */
/*============================================================================*/
int npcompTypeIsAAnyDtype(MlirType t) { return unwrap(t).isa<AnyDtypeType>(); }
int npcompTypeIsAAnyDtype(MlirType t) {
return unwrap(t).isa<Numpy::AnyDtypeType>();
}
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
return wrap(AnyDtypeType::get(unwrap(context)));
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
}
/*============================================================================*/
/* Bool type. */
/*============================================================================*/
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<Basicpy::BoolType>(); }
MlirType npcompBoolTypeGet(MlirContext context) {
return wrap(Basicpy::BoolType::get(unwrap(context)));
}
/*============================================================================*/
/* Dict type. */
/*============================================================================*/
/** Checks whether the given type is the Python "dict" type. */
int npcompTypeIsADict(MlirType t) { return unwrap(t).isa<Basicpy::DictType>(); }
/** Gets the generic Python "dict" type. */
MlirType npcompDictTypeGet(MlirContext context) {
return wrap(Basicpy::DictType::get(unwrap(context)));
}
/*============================================================================*/
/* List type. */
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsAList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
/** Gets the generic Python "dict" type. */
MlirType npcompListTypeGet(MlirContext context) {
return wrap(Basicpy::ListType::get(unwrap(context)));
}
/*============================================================================*/
/* NDArray type. */
/*============================================================================*/
int npcompTypeIsANdArray(MlirType t) { return unwrap(t).isa<NdArrayType>(); }
int npcompTypeIsANdArray(MlirType t) {
return unwrap(t).isa<Numpy::NdArrayType>();
}
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType) {
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
}
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
return wrap(
NdArrayType::getFromShapedType(unwrap(shapedType).cast<ShapedType>()));
return wrap(Numpy::NdArrayType::getFromShapedType(
unwrap(shapedType).cast<ShapedType>()));
}
/*============================================================================*/
/* None type. */
/*============================================================================*/
/** Checks whether the given type is the type of the singleton 'None' value. */
int npcompTypeIsANone(MlirType t) { return unwrap(t).isa<Basicpy::NoneType>(); }
/** Gets the type of the singleton 'None'. */
MlirType npcompNoneTypeGet(MlirContext context) {
return wrap(Basicpy::NoneType::get(unwrap(context)));
}
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/
/** Checks whether the given type is the Python "tuple" type. */
int npcompTypeIsATuple(MlirType t) {
return unwrap(t).isa<Basicpy::TupleType>();
}
/** Gets the "any dtype" type. */
MlirType npcompTupleTypeGet(MlirContext context) {
return wrap(Basicpy::TupleType::get(unwrap(context)));
}

View File

@ -20,8 +20,8 @@ void BasicpyDialect::initialize() {
#define GET_OP_LIST
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"
>();
addTypes<BoolType, BytesType, EllipsisType, NoneType, SlotObjectType, StrType,
UnknownType>();
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
SlotObjectType, StrType, TupleType, UnknownType>();
// TODO: Make real ops for everything we need.
allowUnknownOperations();
@ -36,8 +36,12 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
return BoolType::get(getContext());
if (keyword == "BytesType")
return BytesType::get(getContext());
if (keyword == "DictType")
return DictType::get(getContext());
if (keyword == "EllipsisType")
return EllipsisType::get(getContext());
if (keyword == "ListType")
return ListType::get(getContext());
if (keyword == "NoneType")
return NoneType::get(getContext());
if (keyword == "SlotObject") {
@ -60,6 +64,8 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
}
if (keyword == "StrType")
return StrType::get(getContext());
if (keyword == "TupleType")
return TupleType::get(getContext());
if (keyword == "UnknownType")
return UnknownType::get(getContext());
@ -71,7 +77,9 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<BoolType>([&](Type) { os << "BoolType"; })
.Case<BytesType>([&](Type) { os << "BytesType"; })
.Case<DictType>([&](Type) { os << "DictType"; })
.Case<EllipsisType>([&](Type) { os << "EllipsisType"; })
.Case<ListType>([&](Type) { os << "ListType"; })
.Case<NoneType>([&](Type) { os << "NoneType"; })
.Case<SlotObjectType>([&](SlotObjectType slotObject) {
auto slotTypes = slotObject.getSlotTypes();
@ -84,6 +92,7 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
os << ">";
})
.Case<StrType>([&](Type) { os << "StrType"; })
.Case<TupleType>([&](Type) { os << "TupleType"; })
.Case<UnknownType>([&](Type) { os << "UnknownType"; })
.Default(
[&](Type) { llvm_unreachable("unexpected 'basicpy' type kind"); });

View File

@ -0,0 +1,27 @@
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
// -----
// CHECK-LABEL: @build_dict_generic
func @build_dict_generic() -> !basicpy.DictType {
// CHECK: basicpy.build_dict : () -> !basicpy.DictType
%0 = basicpy.build_dict : () -> !basicpy.DictType
return %0 : !basicpy.DictType
}
// -----
// CHECK-LABEL: @build_list_generic
func @build_list_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.ListType {
// CHECK: basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType
%0 = basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType
return %0 : !basicpy.ListType
}
// -----
// CHECK-LABEL: @build_tuple_generic
func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType {
// CHECK: basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType
%0 = basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType
return %0 : !basicpy.TupleType
}
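The roundtrip test above exercises the new build_dict/build_list/build_tuple ops from the textual side. For completeness, a hedged sketch of constructing the same ops from C++, assuming the default ODS-generated builders and header paths that follow the repo's existing layout (the function and value names are illustrative):

#include "mlir/IR/Builders.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" // path assumed
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"     // path assumed

namespace Basicpy = mlir::NPCOMP::Basicpy;

// Emit basicpy.build_list / build_tuple / build_dict with an OpBuilder.
// `builder`, `loc`, and the element values `a` and `b` are assumed to be
// provided by the caller.
static void buildContainers(mlir::OpBuilder &builder, mlir::Location loc,
                            mlir::Value a, mlir::Value b) {
  mlir::MLIRContext *ctx = builder.getContext();
  auto list = builder.create<Basicpy::BuildListOp>(
      loc, Basicpy::ListType::get(ctx), mlir::ValueRange{a, b});
  auto tuple = builder.create<Basicpy::BuildTupleOp>(
      loc, Basicpy::TupleType::get(ctx), mlir::ValueRange{a, b});
  auto dict =
      builder.create<Basicpy::BuildDictOp>(loc, Basicpy::DictType::get(ctx));
  (void)list; (void)tuple; (void)dict;
}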