From 029815152e6a072e3a7f8d817accaae18d6877e1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 16 Oct 2020 17:38:07 -0700 Subject: [PATCH] Add remaining pieces to capture full example models. * Adds Basicpy List, Tuple, Dict types and plumbs through C API. * Started debugging the issues around aten::conv2d capture, but a PyTorch bug is suspected. * Was able to manually verify that the basic conv2d forward test captures correctly with a workaround. * Need to resolve some printing issues upstream and move these tests to an integration test target (they take ~seconds to run). --- .../csrc/c10_dispatch/acap_dispatch.cpp | 43 ++++++--- .../csrc/c10_dispatch/func_builder.cpp | 23 ++++- .../pytorch/csrc/c10_dispatch/func_builder.h | 14 ++- .../test/acap_export/test_export_ResA.py | 6 +- .../acap_export/test_export_conv2d_back.py | 50 ----------- .../acap_export/test_export_conv2d_fwd.py | 41 +++++++++ .../test/acap_export/test_export_resnet18.py | 8 +- .../test/acap_export/test_export_vgg11.py | 6 +- include/npcomp-c/Types.h | 53 +++++++++-- .../Dialect/Basicpy/IR/BasicpyDialect.h | 21 +++++ .../Dialect/Basicpy/IR/BasicpyDialect.td | 49 +++++++++- .../npcomp/Dialect/Basicpy/IR/BasicpyOps.td | 62 ++++++++++++- lib/CAPI/Types.cpp | 90 +++++++++++++++---- lib/Dialect/Basicpy/IR/BasicpyDialect.cpp | 13 ++- test/Dialect/Basicpy/ops.mlir | 27 ++++++ 15 files changed, 404 insertions(+), 102 deletions(-) delete mode 100644 frontends/pytorch/test/acap_export/test_export_conv2d_back.py create mode 100644 frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py create mode 100644 test/Dialect/Basicpy/ops.mlir diff --git a/frontends/pytorch/csrc/c10_dispatch/acap_dispatch.cpp b/frontends/pytorch/csrc/c10_dispatch/acap_dispatch.cpp index 5035db8d7..093566e27 100644 --- a/frontends/pytorch/csrc/c10_dispatch/acap_dispatch.cpp +++ b/frontends/pytorch/csrc/c10_dispatch/acap_dispatch.cpp @@ -31,7 +31,11 @@ using c10::Stack; // that the TORCH_LIBRARY_* macros expand this by name and other APIs use its // enum value, so we define both. We can get rid of both once we have our // own key. -#define ACAP_DISPATCH_KEY PrivateUse1 +// TODO: Ask the PT devs why conv is special and only shows up if dispatching +// through the autograd keys. +// https://github.com/llvm/mlir-npcomp/issues/86 +// #define ACAP_DISPATCH_KEY AutogradPrivateUse3 +#define ACAP_DISPATCH_KEY PrivateUse3 static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY; std::list & @@ -74,8 +78,8 @@ void AcapController::returns(std::vector tensors) { // Exclude recursive dispatch in order to print tensor. c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey); std::stringstream msg; - msg << "Cannot return a tensor that is not from the capture context: "; - msg << tensor; + msg << "Cannot return a tensor that is not from the capture context: " + << tensor; throw std::invalid_argument(msg.str()); } @@ -85,8 +89,7 @@ void AcapController::returns(std::vector tensors) { MlirLocation loc = getCurrentLocation(); OperationStateHolder s("std.return", loc); - mlirOperationStateAddOperands(&s.state, returnsValues.size(), - returnsValues.data()); + mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data()); funcBuilder->getEntryBlockBuilder().insertBeforeTerminator( s.createOperation()); funcBuilder->rewriteFuncReturnTypes(returnsTypes); @@ -161,7 +164,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle, MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet( "kernel_name", mlirStringAttrGet(context, kernelName.size(), kernelName.data())); - mlirOperationStateAddAttributes(&stateHolder.state, 1, &kernelNameAttr); + mlirOperationStateAddAttributes(stateHolder, 1, &kernelNameAttr); // Map arguments to operands. // This must be accumulated into the OperationState prior to re-dispatch @@ -179,8 +182,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle, } operands.push_back(mlirValue); } - mlirOperationStateAddOperands(&stateHolder.state, operands.size(), - operands.data()); + mlirOperationStateAddOperands(stateHolder, operands.size(), operands.data()); // Invoke the original kernel. redispatch(opHandle, stack); @@ -205,7 +207,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle, resultIndexToTensorMap.emplace_back(resultIndex, returnIt->toTensor()); } } - mlirOperationStateAddResults(&stateHolder.state, resultTypes.size(), + mlirOperationStateAddResults(stateHolder, resultTypes.size(), resultTypes.data()); // Create operation. @@ -244,9 +246,19 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc, // TODO: Switch to the numpy.bool type as that is a closer domain match. return funcBuilder->getBoolConstant(loc, ival.toBool()); } + if (ival.isList()) { + auto list = ival.toList(); + llvm::SmallVector elements; + for (c10::IValue element : list) { + elements.push_back(mapIValueToMlirValue(loc, element)); + } + return funcBuilder->buildConstantList(loc, elements); + } + if (ival.isNone()) { + return funcBuilder->getNoneConstant(loc); + } return {nullptr}; // TODO: Implement mappings for the whole set (relevant to this use case): - // _(None) // _(Tensor) // _(Double) // _(Int) @@ -277,6 +289,12 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc, // TODO: Switch to the numpy.bool type as that is a closer domain match. return mlirIntegerTypeGet(funcBuilder->getContext(), 1); } + if (ival.isList()) { + return npcompListTypeGet(funcBuilder->getContext()); + } + if (ival.isNone()) { + return npcompNoneTypeGet(funcBuilder->getContext()); + } return {nullptr}; } @@ -357,3 +375,8 @@ TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction< &AcapController::fallbackKernel>()); } + +TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) { + m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction< + &AcapController::fallbackKernel>()); +} diff --git a/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp b/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp index c3d97cd9a..5ea05e456 100644 --- a/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp +++ b/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp @@ -17,8 +17,8 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type, MlirAttribute value) { OperationStateHolder s("std.constant", loc); MlirNamedAttribute valueAttr = mlirNamedAttributeGet("value", value); - mlirOperationStateAddResults(&s.state, 1, &type); - mlirOperationStateAddAttributes(&s.state, 1, &valueAttr); + mlirOperationStateAddResults(s, 1, &type); + mlirOperationStateAddAttributes(s, 1, &valueAttr); return s.createOperation(); } @@ -170,6 +170,14 @@ MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) { return getGeneralConstant(loc, value); } +MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) { + OperationStateHolder state{"basicpy.singleton", loc}; + MlirType noneType = npcompNoneTypeGet(context); + mlirOperationStateAddResults(state, 1, &noneType); + MlirOperation op = state.createOperation(); + return insertConstantOp(op); +} + MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc, MlirAttribute value) { MlirType valueType = mlirAttributeGetType(value); @@ -177,3 +185,14 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc, MlirValue constValue = insertConstantOp(constOp); return constValue; } + +MlirValue +FuncBuilder::buildConstantList(MlirLocation loc, + llvm::SmallVectorImpl &elements) { + MlirType resultType = npcompListTypeGet(context); + OperationStateHolder state{"basicpy.build_list", loc}; + mlirOperationStateAddResults(state, 1, &resultType); + mlirOperationStateAddOperands(state, elements.size(), elements.data()); + MlirOperation op = state.createOperation(); + return insertConstantOp(op); +} diff --git a/frontends/pytorch/csrc/c10_dispatch/func_builder.h b/frontends/pytorch/csrc/c10_dispatch/func_builder.h index 70026901b..09c045a89 100644 --- a/frontends/pytorch/csrc/c10_dispatch/func_builder.h +++ b/frontends/pytorch/csrc/c10_dispatch/func_builder.h @@ -31,15 +31,16 @@ public: } } + operator MlirOperationState *() { return &state; } + MlirOperation createOperation() { assert(owned && "cannot createOperation on unowned state"); owned = false; return mlirOperationCreate(&state); } - MlirOperationState state; - private: + MlirOperationState state; bool owned = true; }; @@ -123,10 +124,19 @@ public: /// Gets a bool constant value. MlirValue getBoolConstant(MlirLocation loc, bool v); + /// Gets a None constant value. + MlirValue getNoneConstant(MlirLocation loc); + /// Gets a general constant value representing the given value /// attribute. MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value); + /// Builds a list with the given elements (derived from constants). + /// The resulting list is inserted into the "constant section" of the + /// function. + MlirValue buildConstantList(MlirLocation loc, + llvm::SmallVectorImpl &elements); + private: FuncBuilder(MlirContext context, MlirOperation funcOp, BlockBuilder entryBlock) diff --git a/frontends/pytorch/test/acap_export/test_export_ResA.py b/frontends/pytorch/test/acap_export/test_export_ResA.py index 936a59848..a92a186c2 100644 --- a/frontends/pytorch/test/acap_export/test_export_ResA.py +++ b/frontends/pytorch/test/acap_export/test_export_ResA.py @@ -11,8 +11,8 @@ import torch.nn.functional as F import torch_mlir -# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/79 # XFAIL: * +# TODO: https://github.com/llvm/mlir-npcomp/issues/86 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s class ResA(nn.Module): @@ -54,4 +54,6 @@ with mb.capture_function("resa", [inputs]) as f: # CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"} # CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"} # CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"} -print(mb.module) +# TODO: Enable printing once large elements can be elided (crashes lit). +# https://github.com/llvm/mlir-npcomp/issues/87 +# print(mb.module) diff --git a/frontends/pytorch/test/acap_export/test_export_conv2d_back.py b/frontends/pytorch/test/acap_export/test_export_conv2d_back.py deleted file mode 100644 index 9da94794b..000000000 --- a/frontends/pytorch/test/acap_export/test_export_conv2d_back.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- Python -*- -# This file is licensed under a pytorch-style license -# See frontends/pytorch/LICENSE for license information. - -import torch -import torch_mlir - -# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80 -# XFAIL: * -# RUN: %PYTHON %s | npcomp-opt | FileCheck %s - -mb = torch_mlir.ModuleBuilder() - -N = 3 -Cin = 16 -Cout = 4 -w = 10 -h = 10 - -model = torch.nn.Conv2d(Cin, Cout, (3,3)) -ref_model = torch.nn.Conv2d(Cin, Cout, (3,3)) - -ref_model.weight.data = model.weight.clone() -ref_model.bias.data = model.bias.clone() - -softmax = torch.nn.LogSoftmax(dim=1) -loss = torch.nn.NLLLoss() - -tensor = torch.randn(N, Cin, h, w) - -with mb.capture_function("@conv2d_fwd", [tensor]) as f: - result = model(tensor) - f.returns([result]) - -target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout) -ref_target = target.clone() - -with mb.capture_function("@conv2d_backward", [result, target]) as f: - test_loss = loss(softmax(result), target) - f.returns([test_loss.backward()]) - -# CHECK-LABEL: func @conv2d_fwd -# TODO: Add checks when passing - -# CHECK-LABEL: func @conv2d_backward -# TODO: Update checks when passing -# NO-CHECK: aten.convolution_overrideable -# NO-CHECK: aten._log_softmax -# NO-CHECK: aten.nll_loss2d_forward -print(mb.module) diff --git a/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py b/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py new file mode 100644 index 000000000..bb2f595af --- /dev/null +++ b/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py @@ -0,0 +1,41 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch_mlir + +# XFAIL: * +# TODO: https://github.com/llvm/mlir-npcomp/issues/86 +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +mb = torch_mlir.ModuleBuilder() + +N = 3 +Cin = 16 +Cout = 4 +w = 10 +h = 10 + +model = torch.nn.Conv2d(Cin, Cout, (3,3)) +ref_model = torch.nn.Conv2d(Cin, Cout, (3,3)) + +ref_model.weight.data = model.weight.clone() +ref_model.bias.data = model.bias.clone() + +softmax = torch.nn.LogSoftmax(dim=1) +loss = torch.nn.NLLLoss() + +tensor = torch.randn(N, Cin, h, w) + +with mb.capture_function("conv2d_fwd", [tensor]) as f: + result = model(tensor) + f.returns([result]) + +# CHECK-LABEL: func @conv2d_fwd +# CHECK-SAME: (%arg0: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> { +# CHECK: %[[P1:.*]] = numpy.create_array_from_tensor %cst : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32> +# CHECK: %[[P2:.*]] = numpy.create_array_from_tensor %cst_0 : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32> +# CHECK: %[[R:.*]] = torch.kernel_call "aten::conv2d" %arg0, %[[P1]], %[[P2]], %0, %1, %2, %c1_i64_5 : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32> +# CHECK: return %[[R]] : !numpy.ndarray<[3,4,8,8]:f32> +print(mb.module) diff --git a/frontends/pytorch/test/acap_export/test_export_resnet18.py b/frontends/pytorch/test/acap_export/test_export_resnet18.py index b696c2275..8079ba56d 100644 --- a/frontends/pytorch/test/acap_export/test_export_resnet18.py +++ b/frontends/pytorch/test/acap_export/test_export_resnet18.py @@ -6,8 +6,9 @@ import torch import torch_mlir import torchvision.models as models -# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80 # XFAIL: * +# TODO: https://github.com/llvm/mlir-npcomp/issues/86 +# TODO: Pass through npcomp-opt and FileCheck once able to elide large elements. # RUN: %PYTHON %s | npcomp-opt | FileCheck %s model = models.resnet18() @@ -21,8 +22,9 @@ with mb.capture_function("res18", [tensor]) as f: result = model(tensor) f.returns([result]) -print(mb.module) - # for now we just check the output shape # CHECK-LABEL: @res18 # TODO: Add checks once running to this point. +# TODO: Enable printing once large elements can be elided (crashes lit). +# https://github.com/llvm/mlir-npcomp/issues/87 +# print(mb.module) diff --git a/frontends/pytorch/test/acap_export/test_export_vgg11.py b/frontends/pytorch/test/acap_export/test_export_vgg11.py index 49bd5dda9..234789357 100644 --- a/frontends/pytorch/test/acap_export/test_export_vgg11.py +++ b/frontends/pytorch/test/acap_export/test_export_vgg11.py @@ -6,8 +6,8 @@ import torch import torch_mlir import torchvision.models as models -# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80 # XFAIL: * +# TODO: https://github.com/llvm/mlir-npcomp/issues/86 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s model = models.vgg11_bn() @@ -23,4 +23,6 @@ with mb.capture_function("vgg11", [inputs]) as f: # CHECK-LABEL: func @vgg11 # TODO: Add checks once passing this far. -print(mb.module) +# TODO: Enable printing once large elements can be elided (crashes lit). +# https://github.com/llvm/mlir-npcomp/issues/87 +# print(mb.module) diff --git a/include/npcomp-c/Types.h b/include/npcomp-c/Types.h index 807741f89..3508e3509 100644 --- a/include/npcomp-c/Types.h +++ b/include/npcomp-c/Types.h @@ -16,6 +16,17 @@ extern "C" { #endif +/*============================================================================*/ +/* Any dtype type. */ +/*============================================================================*/ + +/** Checks whether the given type is the special "any dtype" type that is used + * to signal an NDArray or tensor of unknown type. */ +int npcompTypeIsAAnyDtype(MlirType t); + +/** Gets the "any dtype" type. */ +MlirType npcompAnyDtypeTypeGet(MlirContext context); + /*============================================================================*/ /* Bool type. */ /*============================================================================*/ @@ -27,15 +38,24 @@ int npcompTypeIsABool(MlirType t); MlirType npcompBoolTypeGet(MlirContext context); /*============================================================================*/ -/* Any dtype type. */ +/* Dict type. */ /*============================================================================*/ -/** Checks whether the given type is the special "any dtype" type that is used - * to signal an NDArray or tensor of unknown type. */ -int npcompTypeIsAAnyDtype(MlirType t); +/** Checks whether the given type is the Python "dict" type. */ +int npcompTypeIsADict(MlirType t); -/** Gets the "any dtype" type. */ -MlirType npcompAnyDtypeTypeGet(MlirContext context); +/** Gets the generic Python "dict" type. */ +MlirType npcompDictTypeGet(MlirContext context); + +/*============================================================================*/ +/* List type. */ +/*============================================================================*/ + +/** Checks whether the given type is the Python "list" type. */ +int npcompTypeIsAList(MlirType t); + +/** Gets the generic Python "list" type. */ +MlirType npcompListTypeGet(MlirContext context); /*============================================================================*/ /* NDArray type. */ @@ -52,6 +72,27 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, /// Helper that gets an equivalent NdArrayType from a ShapedType. MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType); +/*============================================================================*/ +/* None type. */ +/*============================================================================*/ + +/** Checks whether the given type is the type of the singleton 'None' value. */ +int npcompTypeIsANone(MlirType t); + +/** Gets the type of the singleton 'None'. */ +MlirType npcompNoneTypeGet(MlirContext context); + +/*============================================================================*/ +/* Tuple type. */ +/*============================================================================*/ + +/** Checks whether the given type is the special "any dtype" type that is used + * to signal an NDArray or tensor of unknown type. */ +int npcompTypeIsATuple(MlirType t); + +/** Gets the generic Python "tuple" type. */ +MlirType npcompTupleTypeGet(MlirContext context); + #ifdef __cplusplus } #endif diff --git a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.h b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.h index bec9b8249..bf11fafed 100644 --- a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.h +++ b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.h @@ -37,6 +37,13 @@ public: static BytesType get(MLIRContext *context) { return Base::get(context); } }; +/// Python 'dict' type. +class DictType : public Type::TypeBase { +public: + using Base::Base; + static DictType get(MLIRContext *context) { return Base::get(context); } +}; + /// The type of the Python `Ellipsis` value. class EllipsisType : public Type::TypeBase { public: @@ -44,6 +51,13 @@ public: static EllipsisType get(MLIRContext *context) { return Base::get(context); } }; +/// Python 'list' type. +class ListType : public Type::TypeBase { +public: + using Base::Base; + static ListType get(MLIRContext *context) { return Base::get(context); } +}; + /// The type of the Python `None` value. class NoneType : public Type::TypeBase { public: @@ -74,6 +88,13 @@ public: static StrType get(MLIRContext *context) { return Base::get(context); } }; +/// Python 'tuple' type. +class TupleType : public Type::TypeBase { +public: + using Base::Base; + static TupleType get(MLIRContext *context) { return Base::get(context); } +}; + /// An unknown type that could be any supported python type. class UnknownType : public Type::TypeBase { diff --git a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td index 226890454..23c7a9acf 100644 --- a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td +++ b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td @@ -19,7 +19,7 @@ def Basicpy_Dialect : Dialect { let name = "basicpy"; let summary = "Basic Python dialect"; let description = [{ - Core types and ops + Core types and ops }]; let cppNamespace = "::mlir::NPCOMP::Basicpy"; } @@ -31,7 +31,7 @@ def Basicpy_Dialect : Dialect { class Basicpy_Op traits = []> : Op { let parser = [{ return parse$cppClass(parser, &result); }]; - let printer = [{ return print$cppClass(p, *this); }]; + let printer = [{ return print$cppClass(p, *this); }]; } //===----------------------------------------------------------------------===// @@ -71,7 +71,7 @@ def Basicpy_NoneType : DialectType()">, + CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">, "Slot object"> { let typeDescription = [{ Type for built-in objects which have a fixed number of slots and a type @@ -96,6 +96,47 @@ def Basicpy_UnknownType : DialectType()">, + "List type"> { + let typeDescription = [{ + A Python list type. In the non-parameterized case, there are limited + constraints on the element type or length; however, it can be refined to + include such constraints. + + As in Python, this list type represents a mutable, reference counted + object in a corresponding runtime layer. + }]; +} + +def Basicpy_TupleType : DialectType()">, + "Tuple type"> { + let typeDescription = [{ + A Python tuple type. In the non-parameterized case, there are limited + constraints on the element type or length; however, it can be refined to + include such constraints. + + As in Python, post-construction tuple's are immutable, reference counted + objects in a corresponding runtime layer. However, since they are + immutable, they can also serve as value-typed entities if their elements + are immutable. + }]; +} + +def Basicpy_DictType : DialectType()">, + "Dict type"> { + let typeDescription = [{ + A Python dict type. In the non-parameterized case, there are limited + constraints on the key or value types; however, it can be refined to + include such constraints. + + As in Python, this list type represents a mutable, reference counted + object in a corresponding runtime layer. + }]; +} + //===----------------------------------------------------------------------===// // Type predicates //===----------------------------------------------------------------------===// @@ -117,7 +158,7 @@ class Basicpy_SlotObjectOfClassArity : // Type representing a 'slice' object, which mirrors the Python built-in // slice class. -def Basicpy_SliceSlotObjectType : +def Basicpy_SliceSlotObjectType : Type>; #endif // NPCOMP_DIALECT_BASICPY_IR_BASICPY_DIALECT diff --git a/include/npcomp/Dialect/Basicpy/IR/BasicpyOps.td b/include/npcomp/Dialect/Basicpy/IR/BasicpyOps.td index e67b8e901..663b381f5 100644 --- a/include/npcomp/Dialect/Basicpy/IR/BasicpyOps.td +++ b/include/npcomp/Dialect/Basicpy/IR/BasicpyOps.td @@ -90,7 +90,7 @@ def CompareOperationAttr : StrEnumAttr< } //===----------------------------------------------------------------------===// -// Ops +// Operations //===----------------------------------------------------------------------===// def Basicpy_BinaryCompareOp : Basicpy_Op<"binary_compare", []> { @@ -154,6 +154,66 @@ def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [ let hasFolder = 1; } +// TODO: Implement ConstantLike op trait. +def Basicpy_BuildDictOp : Basicpy_Op<"build_dict", [NoSideEffect]> { + let summary = "Builds an empty dict"; + let description = [{ + This op mirrors the CPython BUILD_MAP op (note naming difference). + + Note that as with CPython, this op only builds an empty dict; however, + it is reserved in the future for it to take variadic operands to construct + with a list of key/value pairs. + }]; + let arguments = (ins + ); + let results = (outs + Basicpy_DictType:$result + ); + let assemblyFormat = "attr-dict `:` functional-type(operands, results)"; +} + +// TODO: Implement ConstantLike op trait. +def Basicpy_BuildListOp : Basicpy_Op<"build_list", [NoSideEffect]> { + let summary = "Builds a list from operands"; + let description = [{ + Constructs a new list object from its operands. + + TODO: Any allowable type can be expressed in lists; however, this should be + revisited once more of the dialect infrastructure is in place and tightened + up accordingly. At that time, appropriate constraints should be added that + both allow correct program representation and support transformations to + lower levels (i.e. allowing a wider set of types as useful for conversions). + }]; + let arguments = (ins + Variadic:$elements + ); + let results = (outs + Basicpy_ListType:$result + ); + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + +// TODO: Implement ConstantLike op trait. +def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> { + let summary = "Builds a tuple from operands"; + let description = [{ + Constructs a new tuple object from its operands. + + TODO: Any allowable type can be expressed in lists; however, this should be + revisited once more of the dialect infrastructure is in place and tightened + up accordingly. At that time, appropriate constraints should be added that + both allow correct program representation and support transformations to + lower levels (i.e. allowing a wider set of types as useful for conversions). + }]; + let arguments = (ins + Variadic:$elements + ); + let results = (outs + Basicpy_TupleType:$result + ); + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [ ConstantLike, NoSideEffect]> { let summary = "Constant bytes value"; diff --git a/lib/CAPI/Types.cpp b/lib/CAPI/Types.cpp index 0c8d3fc6a..fbacfb4a7 100644 --- a/lib/CAPI/Types.cpp +++ b/lib/CAPI/Types.cpp @@ -14,42 +14,96 @@ #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" using namespace mlir; -using namespace mlir::NPCOMP::Basicpy; -using namespace mlir::NPCOMP::Numpy; - -/*============================================================================*/ -/* Bool type. */ -/*============================================================================*/ - -int npcompTypeIsABool(MlirType t) { return unwrap(t).isa(); } - -MlirType npcompBoolTypeGet(MlirContext context) { - return wrap(BoolType::get(unwrap(context))); -} +using namespace mlir::NPCOMP; /*============================================================================*/ /* Any dtype type. */ /*============================================================================*/ -int npcompTypeIsAAnyDtype(MlirType t) { return unwrap(t).isa(); } +int npcompTypeIsAAnyDtype(MlirType t) { + return unwrap(t).isa(); +} MlirType npcompAnyDtypeTypeGet(MlirContext context) { - return wrap(AnyDtypeType::get(unwrap(context))); + return wrap(Numpy::AnyDtypeType::get(unwrap(context))); +} + +/*============================================================================*/ +/* Bool type. */ +/*============================================================================*/ + +int npcompTypeIsABool(MlirType t) { return unwrap(t).isa(); } + +MlirType npcompBoolTypeGet(MlirContext context) { + return wrap(Basicpy::BoolType::get(unwrap(context))); +} + +/*============================================================================*/ +/* Dict type. */ +/*============================================================================*/ + +/** Checks whether the given type is the Python "dict" type. */ +int npcompTypeIsADict(MlirType t) { return unwrap(t).isa(); } + +/** Gets the generic Python "dict" type. */ +MlirType npcompDictTypeGet(MlirContext context) { + return wrap(Basicpy::DictType::get(unwrap(context))); +} + +/*============================================================================*/ +/* List type. */ +/*============================================================================*/ + +/** Checks whether the given type is the Python "list" type. */ +int npcompTypeIsAList(MlirType t) { return unwrap(t).isa(); } + +/** Gets the generic Python "dict" type. */ +MlirType npcompListTypeGet(MlirContext context) { + return wrap(Basicpy::ListType::get(unwrap(context))); } /*============================================================================*/ /* NDArray type. */ /*============================================================================*/ -int npcompTypeIsANdArray(MlirType t) { return unwrap(t).isa(); } +int npcompTypeIsANdArray(MlirType t) { + return unwrap(t).isa(); +} MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, MlirType elementType) { llvm::ArrayRef shapeArray(shape, rank); - return wrap(NdArrayType::get(unwrap(elementType), shapeArray)); + return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray)); } MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) { - return wrap( - NdArrayType::getFromShapedType(unwrap(shapedType).cast())); + return wrap(Numpy::NdArrayType::getFromShapedType( + unwrap(shapedType).cast())); +} + +/*============================================================================*/ +/* None type. */ +/*============================================================================*/ + +/** Checks whether the given type is the type of the singleton 'None' value. */ +int npcompTypeIsANone(MlirType t) { return unwrap(t).isa(); } + +/** Gets the type of the singleton 'None'. */ +MlirType npcompNoneTypeGet(MlirContext context) { + return wrap(Basicpy::NoneType::get(unwrap(context))); +} + +/*============================================================================*/ +/* Tuple type. */ +/*============================================================================*/ + +/** Checks whether the given type is the special "any dtype" type that is used + * to signal an NDArray or tensor of unknown type. */ +int npcompTypeIsATuple(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the "any dtype" type. */ +MlirType npcompTupleTypeGet(MlirContext context) { + return wrap(Basicpy::TupleType::get(unwrap(context))); } diff --git a/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp b/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp index 388b0eb9b..221381374 100644 --- a/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp +++ b/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp @@ -20,8 +20,8 @@ void BasicpyDialect::initialize() { #define GET_OP_LIST #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc" >(); - addTypes(); + addTypes(); // TODO: Make real ops for everything we need. allowUnknownOperations(); @@ -36,8 +36,12 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const { return BoolType::get(getContext()); if (keyword == "BytesType") return BytesType::get(getContext()); + if (keyword == "DictType") + return DictType::get(getContext()); if (keyword == "EllipsisType") return EllipsisType::get(getContext()); + if (keyword == "ListType") + return ListType::get(getContext()); if (keyword == "NoneType") return NoneType::get(getContext()); if (keyword == "SlotObject") { @@ -60,6 +64,8 @@ Type BasicpyDialect::parseType(DialectAsmParser &parser) const { } if (keyword == "StrType") return StrType::get(getContext()); + if (keyword == "TupleType") + return TupleType::get(getContext()); if (keyword == "UnknownType") return UnknownType::get(getContext()); @@ -71,7 +77,9 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "BoolType"; }) .Case([&](Type) { os << "BytesType"; }) + .Case([&](Type) { os << "DictType"; }) .Case([&](Type) { os << "EllipsisType"; }) + .Case([&](Type) { os << "ListType"; }) .Case([&](Type) { os << "NoneType"; }) .Case([&](SlotObjectType slotObject) { auto slotTypes = slotObject.getSlotTypes(); @@ -84,6 +92,7 @@ void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const { os << ">"; }) .Case([&](Type) { os << "StrType"; }) + .Case([&](Type) { os << "TupleType"; }) .Case([&](Type) { os << "UnknownType"; }) .Default( [&](Type) { llvm_unreachable("unexpected 'basicpy' type kind"); }); diff --git a/test/Dialect/Basicpy/ops.mlir b/test/Dialect/Basicpy/ops.mlir new file mode 100644 index 000000000..1856a302e --- /dev/null +++ b/test/Dialect/Basicpy/ops.mlir @@ -0,0 +1,27 @@ +// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s + +// ----- +// CHECK-LABEL: @build_dict_generic +func @build_dict_generic() -> !basicpy.DictType { + // CHECK: basicpy.build_dict : () -> !basicpy.DictType + %0 = basicpy.build_dict : () -> !basicpy.DictType + return %0 : !basicpy.DictType +} + +// ----- +// CHECK-LABEL: @build_list_generic +func @build_list_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.ListType { + // CHECK: basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType + %0 = basicpy.build_list %arg0, %arg1 : (si32, si32) -> !basicpy.ListType + return %0 : !basicpy.ListType +} + +// ----- +// CHECK-LABEL: @build_tuple_generic +func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType { + // CHECK: basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType + %0 = basicpy.build_tuple %arg0, %arg1 : (si32, si32) -> !basicpy.TupleType + return %0 : !basicpy.TupleType +} + +