mirror of https://github.com/llvm/torch-mlir
Add remaining pieces to capture full example models.
* Adds Basicpy List, Tuple, Dict types and plumbs through C API. * Started debugging the issues around aten::conv2d capture, but a PyTorch bug is suspected. * Was able to manually verify that the basic conv2d forward test captures correctly with a workaround. * Need to resolve some printing issues upstream and move these tests to an integration test target (they take ~seconds to run).pull/88/head
parent
81119aa0a1
commit
029815152e
|
@ -31,7 +31,11 @@ using c10::Stack;
|
|||
// that the TORCH_LIBRARY_* macros expand this by name and other APIs use its
|
||||
// enum value, so we define both. We can get rid of both once we have our
|
||||
// own key.
|
||||
#define ACAP_DISPATCH_KEY PrivateUse1
|
||||
// TODO: Ask the PT devs why conv is special and only shows up if dispatching
|
||||
// through the autograd keys.
|
||||
// https://github.com/llvm/mlir-npcomp/issues/86
|
||||
// #define ACAP_DISPATCH_KEY AutogradPrivateUse3
|
||||
#define ACAP_DISPATCH_KEY PrivateUse3
|
||||
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
|
||||
|
||||
std::list<AcapController::Activation> &
|
||||
|
@ -74,8 +78,8 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
|
|||
// Exclude recursive dispatch in order to print tensor.
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||
std::stringstream msg;
|
||||
msg << "Cannot return a tensor that is not from the capture context: ";
|
||||
msg << tensor;
|
||||
msg << "Cannot return a tensor that is not from the capture context: "
|
||||
<< tensor;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
@ -85,8 +89,7 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
|
|||
|
||||
MlirLocation loc = getCurrentLocation();
|
||||
OperationStateHolder s("std.return", loc);
|
||||
mlirOperationStateAddOperands(&s.state, returnsValues.size(),
|
||||
returnsValues.data());
|
||||
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
|
||||
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(
|
||||
s.createOperation());
|
||||
funcBuilder->rewriteFuncReturnTypes(returnsTypes);
|
||||
|
@ -161,7 +164,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
|||
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
|
||||
"kernel_name",
|
||||
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
|
||||
mlirOperationStateAddAttributes(&stateHolder.state, 1, &kernelNameAttr);
|
||||
mlirOperationStateAddAttributes(stateHolder, 1, &kernelNameAttr);
|
||||
|
||||
// Map arguments to operands.
|
||||
// This must be accumulated into the OperationState prior to re-dispatch
|
||||
|
@ -179,8 +182,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
|||
}
|
||||
operands.push_back(mlirValue);
|
||||
}
|
||||
mlirOperationStateAddOperands(&stateHolder.state, operands.size(),
|
||||
operands.data());
|
||||
mlirOperationStateAddOperands(stateHolder, operands.size(), operands.data());
|
||||
|
||||
// Invoke the original kernel.
|
||||
redispatch(opHandle, stack);
|
||||
|
@ -205,7 +207,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
|||
resultIndexToTensorMap.emplace_back(resultIndex, returnIt->toTensor());
|
||||
}
|
||||
}
|
||||
mlirOperationStateAddResults(&stateHolder.state, resultTypes.size(),
|
||||
mlirOperationStateAddResults(stateHolder, resultTypes.size(),
|
||||
resultTypes.data());
|
||||
|
||||
// Create operation.
|
||||
|
@ -244,9 +246,19 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
// TODO: Switch to the numpy.bool type as that is a closer domain match.
|
||||
return funcBuilder->getBoolConstant(loc, ival.toBool());
|
||||
}
|
||||
if (ival.isList()) {
|
||||
auto list = ival.toList();
|
||||
llvm::SmallVector<MlirValue, 4> elements;
|
||||
for (c10::IValue element : list) {
|
||||
elements.push_back(mapIValueToMlirValue(loc, element));
|
||||
}
|
||||
return funcBuilder->buildConstantList(loc, elements);
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
}
|
||||
return {nullptr};
|
||||
// TODO: Implement mappings for the whole set (relevant to this use case):
|
||||
// _(None)
|
||||
// _(Tensor)
|
||||
// _(Double)
|
||||
// _(Int)
|
||||
|
@ -277,6 +289,12 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
|||
// TODO: Switch to the numpy.bool type as that is a closer domain match.
|
||||
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||
}
|
||||
if (ival.isList()) {
|
||||
return npcompListTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return npcompNoneTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
return {nullptr};
|
||||
}
|
||||
|
||||
|
@ -357,3 +375,8 @@ TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
|
|||
m.fallback(torch::CppFunction::makeFromBoxedFunction<
|
||||
&AcapController::fallbackKernel>());
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
|
||||
m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<
|
||||
&AcapController::fallbackKernel>());
|
||||
}
|
||||
|
|
|
@ -17,8 +17,8 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
|
|||
MlirAttribute value) {
|
||||
OperationStateHolder s("std.constant", loc);
|
||||
MlirNamedAttribute valueAttr = mlirNamedAttributeGet("value", value);
|
||||
mlirOperationStateAddResults(&s.state, 1, &type);
|
||||
mlirOperationStateAddAttributes(&s.state, 1, &valueAttr);
|
||||
mlirOperationStateAddResults(s, 1, &type);
|
||||
mlirOperationStateAddAttributes(s, 1, &valueAttr);
|
||||
return s.createOperation();
|
||||
}
|
||||
|
||||
|
@ -170,6 +170,14 @@ MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
|
|||
return getGeneralConstant(loc, value);
|
||||
}
|
||||
|
||||
MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
|
||||
OperationStateHolder state{"basicpy.singleton", loc};
|
||||
MlirType noneType = npcompNoneTypeGet(context);
|
||||
mlirOperationStateAddResults(state, 1, &noneType);
|
||||
MlirOperation op = state.createOperation();
|
||||
return insertConstantOp(op);
|
||||
}
|
||||
|
||||
MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
||||
MlirAttribute value) {
|
||||
MlirType valueType = mlirAttributeGetType(value);
|
||||
|
@ -177,3 +185,14 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
|||
MlirValue constValue = insertConstantOp(constOp);
|
||||
return constValue;
|
||||
}
|
||||
|
||||
MlirValue
|
||||
FuncBuilder::buildConstantList(MlirLocation loc,
|
||||
llvm::SmallVectorImpl<MlirValue> &elements) {
|
||||
MlirType resultType = npcompListTypeGet(context);
|
||||
OperationStateHolder state{"basicpy.build_list", loc};
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||
MlirOperation op = state.createOperation();
|
||||
return insertConstantOp(op);
|
||||
}
|
||||
|
|
|
@ -31,15 +31,16 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
operator MlirOperationState *() { return &state; }
|
||||
|
||||
MlirOperation createOperation() {
|
||||
assert(owned && "cannot createOperation on unowned state");
|
||||
owned = false;
|
||||
return mlirOperationCreate(&state);
|
||||
}
|
||||
|
||||
MlirOperationState state;
|
||||
|
||||
private:
|
||||
MlirOperationState state;
|
||||
bool owned = true;
|
||||
};
|
||||
|
||||
|
@ -123,10 +124,19 @@ public:
|
|||
/// Gets a bool constant value.
|
||||
MlirValue getBoolConstant(MlirLocation loc, bool v);
|
||||
|
||||
/// Gets a None constant value.
|
||||
MlirValue getNoneConstant(MlirLocation loc);
|
||||
|
||||
/// Gets a general constant value representing the given value
|
||||
/// attribute.
|
||||
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
||||
|
||||
/// Builds a list with the given elements (derived from constants).
|
||||
/// The resulting list is inserted into the "constant section" of the
|
||||
/// function.
|
||||
MlirValue buildConstantList(MlirLocation loc,
|
||||
llvm::SmallVectorImpl<MlirValue> &elements);
|
||||
|
||||
private:
|
||||
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
||||
BlockBuilder entryBlock)
|
||||
|
|
|
@ -11,8 +11,8 @@ import torch.nn.functional as F
|
|||
|
||||
import torch_mlir
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/79
|
||||
# XFAIL: *
|
||||
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
class ResA(nn.Module):
|
||||
|
@ -54,4 +54,6 @@ with mb.capture_function("resa", [inputs]) as f:
|
|||
# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
|
||||
# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
|
||||
# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
|
||||
print(mb.module)
|
||||
# TODO: Enable printing once large elements can be elided (crashes lit).
|
||||
# https://github.com/llvm/mlir-npcomp/issues/87
|
||||
# print(mb.module)
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
N = 3
|
||||
Cin = 16
|
||||
Cout = 4
|
||||
w = 10
|
||||
h = 10
|
||||
|
||||
model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
|
||||
ref_model.weight.data = model.weight.clone()
|
||||
ref_model.bias.data = model.bias.clone()
|
||||
|
||||
softmax = torch.nn.LogSoftmax(dim=1)
|
||||
loss = torch.nn.NLLLoss()
|
||||
|
||||
tensor = torch.randn(N, Cin, h, w)
|
||||
|
||||
with mb.capture_function("@conv2d_fwd", [tensor]) as f:
|
||||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
|
||||
ref_target = target.clone()
|
||||
|
||||
with mb.capture_function("@conv2d_backward", [result, target]) as f:
|
||||
test_loss = loss(softmax(result), target)
|
||||
f.returns([test_loss.backward()])
|
||||
|
||||
# CHECK-LABEL: func @conv2d_fwd
|
||||
# TODO: Add checks when passing
|
||||
|
||||
# CHECK-LABEL: func @conv2d_backward
|
||||
# TODO: Update checks when passing
|
||||
# NO-CHECK: aten.convolution_overrideable
|
||||
# NO-CHECK: aten._log_softmax
|
||||
# NO-CHECK: aten.nll_loss2d_forward
|
||||
print(mb.module)
|
|
@ -0,0 +1,41 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# XFAIL: *
|
||||
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
N = 3
|
||||
Cin = 16
|
||||
Cout = 4
|
||||
w = 10
|
||||
h = 10
|
||||
|
||||
model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
|
||||
ref_model.weight.data = model.weight.clone()
|
||||
ref_model.bias.data = model.bias.clone()
|
||||
|
||||
softmax = torch.nn.LogSoftmax(dim=1)
|
||||
loss = torch.nn.NLLLoss()
|
||||
|
||||
tensor = torch.randn(N, Cin, h, w)
|
||||
|
||||
with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
||||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
# CHECK-LABEL: func @conv2d_fwd
|
||||
# CHECK-SAME: (%arg0: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
||||
# CHECK: %[[P1:.*]] = numpy.create_array_from_tensor %cst : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
||||
# CHECK: %[[P2:.*]] = numpy.create_array_from_tensor %cst_0 : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
||||
# CHECK: %[[R:.*]] = torch.kernel_call "aten::conv2d" %arg0, %[[P1]], %[[P2]], %0, %1, %2, %c1_i64_5 : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
|
||||
# CHECK: return %[[R]] : !numpy.ndarray<[3,4,8,8]:f32>
|
||||
print(mb.module)
|
|
@ -6,8 +6,9 @@ import torch
|
|||
import torch_mlir
|
||||
import torchvision.models as models
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
|
||||
# TODO: Pass through npcomp-opt and FileCheck once able to elide large elements.
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
model = models.resnet18()
|
||||
|
@ -21,8 +22,9 @@ with mb.capture_function("res18", [tensor]) as f:
|
|||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
print(mb.module)
|
||||
|
||||
# for now we just check the output shape
|
||||
# CHECK-LABEL: @res18
|
||||
# TODO: Add checks once running to this point.
|
||||
# TODO: Enable printing once large elements can be elided (crashes lit).
|
||||
# https://github.com/llvm/mlir-npcomp/issues/87
|
||||
# print(mb.module)
|
||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
|||
import torch_mlir
|
||||
import torchvision.models as models
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
model = models.vgg11_bn()
|
||||
|
@ -23,4 +23,6 @@ with mb.capture_function("vgg11", [inputs]) as f:
|
|||
|
||||
# CHECK-LABEL: func @vgg11
|
||||
# TODO: Add checks once passing this far.
|
||||
print(mb.module)
|
||||
# TODO: Enable printing once large elements can be elided (crashes lit).
|
||||
# https://github.com/llvm/mlir-npcomp/issues/87
|
||||
# print(mb.module)
|
||||
|
|
|
@ -16,6 +16,17 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsAAnyDtype(MlirType t);
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
@ -27,15 +38,24 @@ int npcompTypeIsABool(MlirType t);
|
|||
MlirType npcompBoolTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/* Dict type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsAAnyDtype(MlirType t);
|
||||
/** Checks whether the given type is the Python "dict" type. */
|
||||
int npcompTypeIsADict(MlirType t);
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context);
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompDictTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* List type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsAList(MlirType t);
|
||||
|
||||
/** Gets the generic Python "list" type. */
|
||||
MlirType npcompListTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
|
@ -52,6 +72,27 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
|||
/// Helper that gets an equivalent NdArrayType from a ShapedType.
|
||||
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* None type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the type of the singleton 'None' value. */
|
||||
int npcompTypeIsANone(MlirType t);
|
||||
|
||||
/** Gets the type of the singleton 'None'. */
|
||||
MlirType npcompNoneTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Tuple type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsATuple(MlirType t);
|
||||
|
||||
/** Gets the generic Python "tuple" type. */
|
||||
MlirType npcompTupleTypeGet(MlirContext context);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -37,6 +37,13 @@ public:
|
|||
static BytesType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// Python 'dict' type.
|
||||
class DictType : public Type::TypeBase<DictType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static DictType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// The type of the Python `Ellipsis` value.
|
||||
class EllipsisType : public Type::TypeBase<EllipsisType, Type, TypeStorage> {
|
||||
public:
|
||||
|
@ -44,6 +51,13 @@ public:
|
|||
static EllipsisType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// Python 'list' type.
|
||||
class ListType : public Type::TypeBase<ListType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static ListType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// The type of the Python `None` value.
|
||||
class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
|
||||
public:
|
||||
|
@ -74,6 +88,13 @@ public:
|
|||
static StrType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// Python 'tuple' type.
|
||||
class TupleType : public Type::TypeBase<TupleType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static TupleType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
/// An unknown type that could be any supported python type.
|
||||
class UnknownType : public Type::TypeBase<UnknownType, Type, TypeStorage,
|
||||
NPCOMPTypingTypeMapInterface::Trait> {
|
||||
|
|
|
@ -19,7 +19,7 @@ def Basicpy_Dialect : Dialect {
|
|||
let name = "basicpy";
|
||||
let summary = "Basic Python dialect";
|
||||
let description = [{
|
||||
Core types and ops
|
||||
Core types and ops
|
||||
}];
|
||||
let cppNamespace = "::mlir::NPCOMP::Basicpy";
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ def Basicpy_Dialect : Dialect {
|
|||
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Basicpy_Dialect, mnemonic, traits> {
|
||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -71,7 +71,7 @@ def Basicpy_NoneType : DialectType<Basicpy_Dialect,
|
|||
}
|
||||
|
||||
def Basicpy_SlotObjectType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">,
|
||||
"Slot object"> {
|
||||
let typeDescription = [{
|
||||
Type for built-in objects which have a fixed number of slots and a type
|
||||
|
@ -96,6 +96,47 @@ def Basicpy_UnknownType : DialectType<Basicpy_Dialect,
|
|||
}];
|
||||
}
|
||||
|
||||
def Basicpy_ListType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::ListType>()">,
|
||||
"List type"> {
|
||||
let typeDescription = [{
|
||||
A Python list type. In the non-parameterized case, there are limited
|
||||
constraints on the element type or length; however, it can be refined to
|
||||
include such constraints.
|
||||
|
||||
As in Python, this list type represents a mutable, reference counted
|
||||
object in a corresponding runtime layer.
|
||||
}];
|
||||
}
|
||||
|
||||
def Basicpy_TupleType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::TupleType>()">,
|
||||
"Tuple type"> {
|
||||
let typeDescription = [{
|
||||
A Python tuple type. In the non-parameterized case, there are limited
|
||||
constraints on the element type or length; however, it can be refined to
|
||||
include such constraints.
|
||||
|
||||
As in Python, post-construction tuple's are immutable, reference counted
|
||||
objects in a corresponding runtime layer. However, since they are
|
||||
immutable, they can also serve as value-typed entities if their elements
|
||||
are immutable.
|
||||
}];
|
||||
}
|
||||
|
||||
def Basicpy_DictType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::DictType>()">,
|
||||
"Dict type"> {
|
||||
let typeDescription = [{
|
||||
A Python dict type. In the non-parameterized case, there are limited
|
||||
constraints on the key or value types; however, it can be refined to
|
||||
include such constraints.
|
||||
|
||||
As in Python, this list type represents a mutable, reference counted
|
||||
object in a corresponding runtime layer.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -117,7 +158,7 @@ class Basicpy_SlotObjectOfClassArity<string className, int arity> :
|
|||
|
||||
// Type representing a 'slice' object, which mirrors the Python built-in
|
||||
// slice class.
|
||||
def Basicpy_SliceSlotObjectType :
|
||||
def Basicpy_SliceSlotObjectType :
|
||||
Type<Basicpy_SlotObjectOfClassArity<"slice", 3>>;
|
||||
|
||||
#endif // NPCOMP_DIALECT_BASICPY_IR_BASICPY_DIALECT
|
||||
|
|
|
@ -90,7 +90,7 @@ def CompareOperationAttr : StrEnumAttr<
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops
|
||||
// Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_BinaryCompareOp : Basicpy_Op<"binary_compare", []> {
|
||||
|
@ -154,6 +154,66 @@ def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO: Implement ConstantLike op trait.
|
||||
def Basicpy_BuildDictOp : Basicpy_Op<"build_dict", [NoSideEffect]> {
|
||||
let summary = "Builds an empty dict";
|
||||
let description = [{
|
||||
This op mirrors the CPython BUILD_MAP op (note naming difference).
|
||||
|
||||
Note that as with CPython, this op only builds an empty dict; however,
|
||||
it is reserved in the future for it to take variadic operands to construct
|
||||
with a list of key/value pairs.
|
||||
}];
|
||||
let arguments = (ins
|
||||
);
|
||||
let results = (outs
|
||||
Basicpy_DictType:$result
|
||||
);
|
||||
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
// TODO: Implement ConstantLike op trait.
|
||||
def Basicpy_BuildListOp : Basicpy_Op<"build_list", [NoSideEffect]> {
|
||||
let summary = "Builds a list from operands";
|
||||
let description = [{
|
||||
Constructs a new list object from its operands.
|
||||
|
||||
TODO: Any allowable type can be expressed in lists; however, this should be
|
||||
revisited once more of the dialect infrastructure is in place and tightened
|
||||
up accordingly. At that time, appropriate constraints should be added that
|
||||
both allow correct program representation and support transformations to
|
||||
lower levels (i.e. allowing a wider set of types as useful for conversions).
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$elements
|
||||
);
|
||||
let results = (outs
|
||||
Basicpy_ListType:$result
|
||||
);
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
// TODO: Implement ConstantLike op trait.
|
||||
def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
|
||||
let summary = "Builds a tuple from operands";
|
||||
let description = [{
|
||||
Constructs a new tuple object from its operands.
|
||||
|
||||
TODO: Any allowable type can be expressed in lists; however, this should be
|
||||
revisited once more of the dialect infrastructure is in place and tightened
|
||||
up accordingly. At that time, appropriate constraints should be added that
|
||||
both allow correct program representation and support transformations to
|
||||
lower levels (i.e. allowing a wider set of types as useful for conversions).
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$elements
|
||||
);
|
||||
let results = (outs
|
||||
Basicpy_TupleType:$result
|
||||
);
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
let summary = "Constant bytes value";
|
||||
|
|
|
@ -14,42 +14,96 @@
|
|||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::Basicpy;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<BoolType>(); }
|
||||
|
||||
MlirType npcompBoolTypeGet(MlirContext context) {
|
||||
return wrap(BoolType::get(unwrap(context)));
|
||||
}
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsAAnyDtype(MlirType t) { return unwrap(t).isa<AnyDtypeType>(); }
|
||||
int npcompTypeIsAAnyDtype(MlirType t) {
|
||||
return unwrap(t).isa<Numpy::AnyDtypeType>();
|
||||
}
|
||||
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
|
||||
return wrap(AnyDtypeType::get(unwrap(context)));
|
||||
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<Basicpy::BoolType>(); }
|
||||
|
||||
MlirType npcompBoolTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::BoolType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Dict type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "dict" type. */
|
||||
int npcompTypeIsADict(MlirType t) { return unwrap(t).isa<Basicpy::DictType>(); }
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompDictTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::DictType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* List type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsAList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompListTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::ListType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsANdArray(MlirType t) { return unwrap(t).isa<NdArrayType>(); }
|
||||
int npcompTypeIsANdArray(MlirType t) {
|
||||
return unwrap(t).isa<Numpy::NdArrayType>();
|
||||
}
|
||||
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType) {
|
||||
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
||||
return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
|
||||
return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
|
||||
}
|
||||
|
||||
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
|
||||
return wrap(
|
||||
NdArrayType::getFromShapedType(unwrap(shapedType).cast<ShapedType>()));
|
||||
return wrap(Numpy::NdArrayType::getFromShapedType(
|
||||
unwrap(shapedType).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* None type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the type of the singleton 'None' value. */
|
||||
int npcompTypeIsANone(MlirType t) { return unwrap(t).isa<Basicpy::NoneType>(); }
|
||||
|
||||
/** Gets the type of the singleton 'None'. */
|
||||
MlirType npcompNoneTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::NoneType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Tuple type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsATuple(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::TupleType>();
|
||||
}
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompTupleTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::TupleType::get(unwrap(context)));
|
||||
}
|
||||
|
|
|
@ -20,8 +20,8 @@ void BasicpyDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"
|
||||
>();
|
||||
addTypes<BoolType, BytesType, EllipsisType, NoneType, SlotObjectType, StrType,
|
||||
UnknownType>();
|
||||
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
|
||||
SlotObjectType, StrType, TupleType, UnknownType>();
|
||||
|
||||
// TODO: Make real ops for everything we need.
|
||||
allowUnknownOperations();
|
||||
|
@ -36,8 +36,12 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
return BoolType::get(getContext());
|
||||
if (keyword == "BytesType")
|
||||
return BytesType::get(getContext());
|
||||
if (keyword == "DictType")
|
||||
return DictType::get(getContext());
|
||||
if (keyword == "EllipsisType")
|
||||
return EllipsisType::get(getContext());
|
||||
if (keyword == "ListType")
|
||||
return ListType::get(getContext());
|
||||
if (keyword == "NoneType")
|
||||
return NoneType::get(getContext());
|
||||
if (keyword == "SlotObject") {
|
||||
|
@ -60,6 +64,8 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
}
|
||||
if (keyword == "StrType")
|
||||
return StrType::get(getContext());
|
||||
if (keyword == "TupleType")
|
||||
return TupleType::get(getContext());
|
||||
if (keyword == "UnknownType")
|
||||
return UnknownType::get(getContext());
|
||||
|
||||
|
@ -71,7 +77,9 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
TypeSwitch<Type>(type)
|
||||
.Case<BoolType>([&](Type) { os << "BoolType"; })
|
||||
.Case<BytesType>([&](Type) { os << "BytesType"; })
|
||||
.Case<DictType>([&](Type) { os << "DictType"; })
|
||||
.Case<EllipsisType>([&](Type) { os << "EllipsisType"; })
|
||||
.Case<ListType>([&](Type) { os << "ListType"; })
|
||||
.Case<NoneType>([&](Type) { os << "NoneType"; })
|
||||
.Case<SlotObjectType>([&](SlotObjectType slotObject) {
|
||||
auto slotTypes = slotObject.getSlotTypes();
|
||||
|
@ -84,6 +92,7 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
os << ">";
|
||||
})
|
||||
.Case<StrType>([&](Type) { os << "StrType"; })
|
||||
.Case<TupleType>([&](Type) { os << "TupleType"; })
|
||||
.Case<UnknownType>([&](Type) { os << "UnknownType"; })
|
||||
.Default(
|
||||
[&](Type) { llvm_unreachable("unexpected 'basicpy' type kind"); });
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @build_dict_generic
|
||||
func @build_dict_generic() -> !basicpy.DictType {
|
||||
// CHECK: basicpy.build_dict : () -> !basicpy.DictType
|
||||
%0 = basicpy.build_dict : () -> !basicpy.DictType
|
||||
return %0 : !basicpy.DictType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @build_list_generic
|
||||
func @build_list_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.ListType {
|
||||
// CHECK: basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType
|
||||
%0 = basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @build_tuple_generic
|
||||
func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType {
|
||||
// CHECK: basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType
|
||||
%0 = basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType
|
||||
return %0 : !basicpy.TupleType
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue