mirror of https://github.com/llvm/torch-mlir
Add basicpy.SlotObject type and ops to create/index into it.
* This is intended to provide low-level modeling for built-in objects. * It is now possible to trace slice tuples (which are tuples of NoneType|EllipsisType|SlotObjectType<slice, ...>).pull/1/head
parent
bfd5fedba7
commit
bc5ef81d68
|
@ -9,7 +9,9 @@
|
|||
#ifndef NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT_H
|
||||
#define NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "npcomp/Dialect/Common.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -18,11 +20,60 @@ namespace Basicpy {
|
|||
|
||||
namespace BasicpyTypes {
|
||||
enum Kind {
|
||||
PlaceholderType = TypeRanges::Basicpy,
|
||||
LAST_BASICPY_TYPE = PlaceholderType
|
||||
// Dialect types.
|
||||
NoneType = TypeRanges::Basicpy,
|
||||
EllipsisType,
|
||||
SlotObjectType,
|
||||
|
||||
// Dialect attributes.
|
||||
SingletonAttr,
|
||||
LAST_BASICPY_TYPE = SingletonAttr,
|
||||
};
|
||||
} // namespace BasicpyTypes
|
||||
|
||||
namespace detail {
|
||||
struct SlotObjectTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// The type of the Python `None` value.
|
||||
class NoneType : public Type::TypeBase<NoneType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == BasicpyTypes::NoneType; }
|
||||
static NoneType get(MLIRContext *context) {
|
||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||
// of this type.
|
||||
return Base::get(context, BasicpyTypes::NoneType);
|
||||
}
|
||||
};
|
||||
|
||||
/// The type of the Python `Ellipsis` value.
|
||||
class EllipsisType : public Type::TypeBase<EllipsisType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == BasicpyTypes::EllipsisType;
|
||||
}
|
||||
static EllipsisType get(MLIRContext *context) {
|
||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||
// of this type.
|
||||
return Base::get(context, BasicpyTypes::EllipsisType);
|
||||
}
|
||||
};
|
||||
|
||||
class SlotObjectType : public Type::TypeBase<SlotObjectType, Type,
|
||||
detail::SlotObjectTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == BasicpyTypes::SlotObjectType;
|
||||
}
|
||||
static SlotObjectType get(StringAttr className, ArrayRef<Type> slotTypes);
|
||||
StringAttr getClassName();
|
||||
unsigned getSlotCount();
|
||||
ArrayRef<Type> getSlotTypes();
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOpsDialect.h.inc"
|
||||
|
||||
} // namespace Basicpy
|
||||
|
|
|
@ -38,10 +38,39 @@ class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// Dialect types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_SlotObjectType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::SlotObjectType>()">,
|
||||
"Slot object"> {
|
||||
let typeDescription = [{
|
||||
Type for built-in objects which have a fixed number of slots and a type
|
||||
name in the system catalog of types. In some ways, this resembles a
|
||||
namedtuple, but it is used for built-in custom objects.
|
||||
}];
|
||||
}
|
||||
|
||||
def Basicpy_NoneType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::NoneType>()">, "None type">,
|
||||
BuildableType<"$_builder.getType<::mlir::NPCOMP::Basicpy::NoneType>()"> {
|
||||
let typeDescription = [{
|
||||
Type of the Python 'None' value.
|
||||
}];
|
||||
}
|
||||
|
||||
def Basicpy_EllipsisType : DialectType<Basicpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Basicpy::EllipsisType>()">, "Ellipsis type">,
|
||||
BuildableType<"$_builder.getType<::mlir::NPCOMP::Basicpy::EllipsisType>()"> {
|
||||
let typeDescription = [{
|
||||
Type of the Python 'Ellipsis' value.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_SingletonType : AnyTypeOf<[
|
||||
Basicpy_NoneType,
|
||||
Basicpy_EllipsisType
|
||||
]>;
|
||||
|
||||
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT
|
||||
|
|
|
@ -13,16 +13,62 @@ include "BasicpyDialect.td"
|
|||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
def Basicpy_ExampleOp : Basicpy_Op<"example", []> {
|
||||
let summary = "Move along, nothing to see here.";
|
||||
def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
let summary = "Constant value for a singleton type";
|
||||
let description = [{
|
||||
Some types only have a single possible value, represented by the
|
||||
SingletonAttr. This op allows creating constants of these types.
|
||||
}];
|
||||
let arguments = (ins);
|
||||
let results = (outs);
|
||||
let assemblyFormat = [{
|
||||
attr-dict
|
||||
}];
|
||||
let results = (outs
|
||||
Basicpy_SingletonType:$result
|
||||
);
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def Basicpy_SlotObjectMakeOp : Basicpy_Op<"slot_object_make", [
|
||||
NoSideEffect]> {
|
||||
let summary = "Creates an instance of a SlotObject type";
|
||||
let description = [{
|
||||
SlotObjects are typically instances of built-in classes that have a fixed
|
||||
number of slots. Unlike in standard python, the types of each slot are
|
||||
tracked.
|
||||
|
||||
This op has a custom assembly form which can be used when valid that
|
||||
omits the operand types (since they are equal to the types in the returned
|
||||
slot object). Example:
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
%1 = basicpy.slot_object_make(%0) ->
|
||||
!basicpy.SlotObject<slice, !basicpy.NoneType>
|
||||
}];
|
||||
let arguments = (ins
|
||||
StrAttr:$className,
|
||||
// TODO: Tighter constraints on allowable types.
|
||||
Variadic<AnyType>:$slots
|
||||
);
|
||||
let results = (outs
|
||||
Basicpy_SlotObjectType:$result
|
||||
);
|
||||
}
|
||||
|
||||
def Basicpy_SlotObjectGetOp : Basicpy_Op<"slot_object_get", [
|
||||
NoSideEffect]> {
|
||||
let summary = "Gets a slot from a slot object";
|
||||
let description = [{
|
||||
Gets a slot from a SlotObject.
|
||||
|
||||
Example:
|
||||
%0 = basicpy.slot_object_make ...
|
||||
%1 = basicpy.slot_object_get %0[1] : !basicpy.SlotObject<...>
|
||||
}];
|
||||
let arguments = (ins
|
||||
Basicpy_SlotObjectType:$object,
|
||||
IndexAttr:$index
|
||||
);
|
||||
let results = (outs
|
||||
AnyType:$result
|
||||
);
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_OPS
|
||||
|
|
|
@ -40,7 +40,7 @@ class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::AnyDtypeType>()">, "any dtype">,
|
||||
BuildableType<"$_builder.getType::mlir::NPCOMP::Numpy::AnyDtypeType()"> {
|
||||
BuildableType<"$_builder.getType<::mlir::NPCOMP::Numpy::AnyDtypeType>()"> {
|
||||
let typeDescription = [{
|
||||
Placeholder for an unknown dtype in a tensor.
|
||||
}];
|
||||
|
|
|
@ -19,17 +19,105 @@ BasicpyDialect::BasicpyDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOps.cpp.inc"
|
||||
>();
|
||||
// addTypes<AnyDtypeType>();
|
||||
addTypes<NoneType, EllipsisType, SlotObjectType>();
|
||||
}
|
||||
|
||||
// Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
||||
// parser.emitError(parser.getNameLoc(), "unknown numpy type");
|
||||
// return Type();
|
||||
// }
|
||||
Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
// void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
// switch (type.getKind()) {
|
||||
// default:
|
||||
// llvm_unreachable("unexpected 'basicpy' type kind");
|
||||
// }
|
||||
// }
|
||||
if (keyword == "NoneType")
|
||||
return NoneType::get(getContext());
|
||||
if (keyword == "EllipsisType")
|
||||
return EllipsisType::get(getContext());
|
||||
if (keyword == "SlotObject") {
|
||||
StringRef className;
|
||||
unsigned slotCount;
|
||||
if (parser.parseLess() || parser.parseKeyword(&className)) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
llvm::SmallVector<Type, 4> slotTypes;
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
Type slotType;
|
||||
if (parser.parseType(slotType))
|
||||
return Type();
|
||||
slotTypes.push_back(slotType);
|
||||
}
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return SlotObjectType::get(StringAttr::get(className, getContext()),
|
||||
slotTypes);
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown basicpy type");
|
||||
return Type();
|
||||
}
|
||||
|
||||
void BasicpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
case BasicpyTypes::NoneType:
|
||||
os << "NoneType";
|
||||
return;
|
||||
case BasicpyTypes::EllipsisType:
|
||||
os << "EllipsisType";
|
||||
return;
|
||||
case BasicpyTypes::SlotObjectType: {
|
||||
auto slotObject = type.cast<SlotObjectType>();
|
||||
auto slotTypes = slotObject.getSlotTypes();
|
||||
os << "SlotObject<" << slotObject.getClassName().getValue();
|
||||
if (!slotTypes.empty()) {
|
||||
os << ", ";
|
||||
llvm::interleaveComma(slotTypes, os, [&](Type t) { os.printType(t); });
|
||||
}
|
||||
os << ">";
|
||||
return;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unexpected 'basicpy' type kind");
|
||||
}
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Type and attribute detail
|
||||
//----------------------------------------------------------------------------//
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Basicpy {
|
||||
namespace detail {
|
||||
|
||||
struct SlotObjectTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::pair<StringAttr, ArrayRef<Type>>;
|
||||
SlotObjectTypeStorage(StringAttr className, ArrayRef<Type> slotTypes)
|
||||
: className(className), slotTypes(slotTypes) {}
|
||||
bool operator==(const KeyTy &other) const {
|
||||
return className == other.first && slotTypes == other.second;
|
||||
}
|
||||
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.first, key.second);
|
||||
}
|
||||
static SlotObjectTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
ArrayRef<Type> slotTypes = allocator.copyInto(key.second);
|
||||
return new (allocator.allocate<SlotObjectTypeStorage>())
|
||||
SlotObjectTypeStorage(key.first, slotTypes);
|
||||
}
|
||||
|
||||
StringAttr className;
|
||||
ArrayRef<Type> slotTypes;
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace Basicpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
StringAttr SlotObjectType::getClassName() { return getImpl()->className; }
|
||||
ArrayRef<Type> SlotObjectType::getSlotTypes() { return getImpl()->slotTypes; }
|
||||
unsigned SlotObjectType::getSlotCount() { return getImpl()->slotTypes.size(); }
|
||||
|
||||
SlotObjectType SlotObjectType::get(StringAttr className,
|
||||
ArrayRef<Type> slotTypes) {
|
||||
return Base::get(className.getContext(), BasicpyTypes::SlotObjectType,
|
||||
className, slotTypes);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,125 @@
|
|||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Basicpy {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SlotObjectMakeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSlotObjectMakeOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
llvm::SmallVector<OpAsmParser::OperandType, 4> operandTypes;
|
||||
if (parser.parseOperandList(operandTypes, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseOptionalAttrDict(result->attributes) ||
|
||||
parser.parseArrowTypeList(result->types)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (result->types.size() != 1 ||
|
||||
!result->types.front().isa<SlotObjectType>()) {
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"custom assembly form requires SlotObject result");
|
||||
}
|
||||
auto slotObjectType = result->types.front().cast<SlotObjectType>();
|
||||
result->addAttribute("className", slotObjectType.getClassName());
|
||||
return parser.resolveOperands(operandTypes, slotObjectType.getSlotTypes(),
|
||||
parser.getNameLoc(), result->operands);
|
||||
}
|
||||
|
||||
static void printSlotObjectMakeOp(OpAsmPrinter &p, SlotObjectMakeOp op) {
|
||||
// If the argument types do not match the result type slots, then
|
||||
// print the generic form.
|
||||
auto canCustomPrint = ([&]() -> bool {
|
||||
auto type = op.result().getType().dyn_cast<SlotObjectType>();
|
||||
if (!type)
|
||||
return false;
|
||||
auto args = op.slots();
|
||||
auto slotTypes = type.getSlotTypes();
|
||||
if (args.size() != slotTypes.size())
|
||||
return false;
|
||||
for (unsigned i = 0, e = args.size(); i < e; ++i) {
|
||||
if (args[i].getType() != slotTypes[i])
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})();
|
||||
if (!canCustomPrint) {
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
p << op.getOperationName() << "(";
|
||||
p.printOperands(op.slots());
|
||||
p << ")";
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"className"});
|
||||
|
||||
// Not really a symbol but satisfies same rules.
|
||||
p.printArrowTypeList(op.getOperation()->getResultTypes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SlotObjectGetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSlotObjectGetOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType object;
|
||||
IntegerAttr indexAttr;
|
||||
Type indexType = parser.getBuilder().getIndexType();
|
||||
if (parser.parseOperand(object) || parser.parseLSquare() ||
|
||||
parser.parseAttribute(indexAttr, indexType, "index",
|
||||
result->attributes) ||
|
||||
parser.parseRSquare()) {
|
||||
return failure();
|
||||
}
|
||||
Type objectType;
|
||||
if (parser.parseColonType(objectType) ||
|
||||
parser.resolveOperand(object, objectType, result->operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto castObjectType = objectType.dyn_cast<SlotObjectType>();
|
||||
if (!castObjectType) {
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"illegal object type on custom assembly form");
|
||||
}
|
||||
auto index = indexAttr.getValue().getZExtValue();
|
||||
auto slotTypes = castObjectType.getSlotTypes();
|
||||
if (index >= slotTypes.size()) {
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"out of bound index on custom assembly form");
|
||||
}
|
||||
result->addTypes({slotTypes[index]});
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printSlotObjectGetOp(OpAsmPrinter &p, SlotObjectGetOp op) {
|
||||
// If the argument types do not match the result type slots, then
|
||||
// print the generic form.
|
||||
auto canCustomPrint = ([&]() -> bool {
|
||||
auto type = op.object().getType().dyn_cast<SlotObjectType>();
|
||||
if (!type)
|
||||
return false;
|
||||
auto index = op.index().getZExtValue();
|
||||
if (index >= type.getSlotCount())
|
||||
return false;
|
||||
if (op.result().getType() != type.getSlotTypes()[index])
|
||||
return false;
|
||||
return true;
|
||||
})();
|
||||
if (!canCustomPrint) {
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
p << op.getOperationName() << " ";
|
||||
p.printOperand(op.object());
|
||||
p << "[" << op.index() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"index"});
|
||||
p << " : ";
|
||||
p.printType(op.object().getType());
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOps.cpp.inc"
|
||||
} // namespace Basicpy
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: @slot_object_make
|
||||
func @slot_object_make() -> (!basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>) {
|
||||
// CHECK: %[[N:.+]] = basicpy.singleton
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
// CHECK: basicpy.slot_object_make(%[[N]], %[[N]], %[[N]]) -> !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>
|
||||
%1 = "basicpy.slot_object_make"(%0, %0, %0) {className = "slice" } :
|
||||
(!basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType) ->
|
||||
(!basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>)
|
||||
return %1 : !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>
|
||||
}
|
||||
|
||||
// -----
|
||||
func @slot_object_get() -> (!basicpy.NoneType) {
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
// CHECK: %[[OBJ:.+]] = basicpy.slot_object_make
|
||||
%1 = basicpy.slot_object_make(%0, %0) -> (!basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType>)
|
||||
// CHECK: basicpy.slot_object_get %[[OBJ]][1] : !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType>
|
||||
%2 = basicpy.slot_object_get %1[1] : !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType>
|
||||
return %2 : !basicpy.NoneType
|
||||
}
|
||||
|
||||
// TODO: Verify illegal forms
|
|
@ -0,0 +1,16 @@
|
|||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: @const_none
|
||||
func @const_none() -> (!basicpy.NoneType) {
|
||||
// CHECK: basicpy.singleton : !basicpy.NoneType
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
return %0 : !basicpy.NoneType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @const_ellipsis
|
||||
func @const_ellipsis() -> (!basicpy.EllipsisType) {
|
||||
// CHECK: basicpy.singleton : !basicpy.EllipsisType
|
||||
%0 = basicpy.singleton : !basicpy.EllipsisType
|
||||
return %0 : !basicpy.EllipsisType
|
||||
}
|
Loading…
Reference in New Issue