mirror of https://github.com/llvm/torch-mlir
Add NdArrayType.
parent
bccfd5f6fc
commit
efe8915901
|
@ -17,10 +17,18 @@ namespace NPCOMP {
|
|||
namespace Numpy {
|
||||
|
||||
namespace NumpyTypes {
|
||||
enum Kind { AnyDtypeType = TypeRanges::Numpy, LAST_NUMPY_TYPE = AnyDtypeType };
|
||||
enum Kind {
|
||||
AnyDtypeType = TypeRanges::Numpy,
|
||||
NdArray,
|
||||
LAST_NUMPY_TYPE = AnyDtypeType,
|
||||
};
|
||||
} // namespace NumpyTypes
|
||||
|
||||
// The singleton type representing an unknown dtype.
|
||||
namespace detail {
|
||||
struct NdArrayTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// The singleton type representing an unknown dtype.
|
||||
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
@ -34,6 +42,15 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class NdArrayType
|
||||
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
|
||||
static NdArrayType get(Type optionalDtype, MLIRContext *context);
|
||||
Type getOptionalDtype();
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
|
||||
|
||||
} // namespace Numpy
|
||||
|
|
|
@ -32,7 +32,7 @@ def Numpy_Dialect : Dialect {
|
|||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Numpy_Dialect, mnemonic, traits> {
|
||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -47,6 +47,43 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
|||
}];
|
||||
}
|
||||
|
||||
def Numpy_NdArrayType : DialectType<Numpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::NdArrayType>()">, "ndarray type">,
|
||||
BuildableType<"$_builder.getType<::mlir::NPCOMP::Numpy::NdArrayType>()"> {
|
||||
let typeDescription = [{
|
||||
NdArrayType: Models a numpy.ndarray and compatible types.
|
||||
Unlike lower level representations, this type solely exists to represent
|
||||
top-level semantics and source-dialect transformations. As such, it
|
||||
is not a general modeling like `tensor` or `memref`, instead being just
|
||||
enough to infer proper lowerings to those types.
|
||||
|
||||
Like its numpy counterparts, NdArrayType represents a mutable array of
|
||||
some value type (dtype), with a shape, strides, and various controls
|
||||
around contiguity. Most of that is not modeled in this type, which focuses
|
||||
on a representation sufficient to infer high level types and aliasing
|
||||
based on program flow.
|
||||
|
||||
Note that most operations in numpy can be legally defined similar to the
|
||||
following:
|
||||
%0 = ... -> !numpy.ndarray<...>
|
||||
%1 = numpy.copy_to_tensor %0 -> tensor<...>
|
||||
%2 = numpy.some_operation %1
|
||||
%4 = numpy.copy_from_tensor -> !numpy.ndarray<...>
|
||||
|
||||
(in other words, the operation does not alias any of its operands to its
|
||||
results)
|
||||
|
||||
When this is the case, the operation will *only* be defined for tensors,
|
||||
as staying in the value domain makes sense for as many operations as
|
||||
can be reasonably represented as such. It is left to subsequent parts of
|
||||
the compiler to transform the program in such a way as to elide the copies
|
||||
that such sequences encode.
|
||||
|
||||
Only ops that mutate or alias their operands should accept and/or produce
|
||||
ndarray types.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -57,7 +94,7 @@ def Numpy_AnyArray : TensorOf<[AnyType]>;
|
|||
def Numpy_SliceTupleElement : AnyTypeOf<[
|
||||
// Supports both "Index Arrays" and "Boolean mask index arrays".
|
||||
Numpy_AnyArray,
|
||||
|
||||
|
||||
// Indicates that an axis should be added (np.newaxis == None).
|
||||
Basicpy_NoneType,
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.cpp.inc"
|
||||
>();
|
||||
addTypes<AnyDtypeType>();
|
||||
addTypes<AnyDtypeType, NdArrayType>();
|
||||
}
|
||||
|
||||
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
||||
|
@ -29,6 +29,22 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
|
||||
if (keyword == "any_dtype")
|
||||
return AnyDtypeType::get(getContext());
|
||||
if (keyword == "ndarray") {
|
||||
// Parse:
|
||||
// ndarray<?>
|
||||
// ndarray<i32>
|
||||
Type dtype;
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
if (failed(parser.parseOptionalQuestion())) {
|
||||
// Specified dtype.
|
||||
if (parser.parseType(dtype))
|
||||
return Type();
|
||||
}
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return NdArrayType::get(dtype, getContext());
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||
return Type();
|
||||
|
@ -39,7 +55,53 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
case NumpyTypes::AnyDtypeType:
|
||||
os << "any_dtype";
|
||||
return;
|
||||
case NumpyTypes::NdArray: {
|
||||
auto ndarray = type.cast<NdArrayType>();
|
||||
auto dtype = ndarray.getOptionalDtype();
|
||||
os << "ndarray<";
|
||||
if (dtype)
|
||||
os.printType(dtype);
|
||||
else
|
||||
os << "?";
|
||||
os << ">";
|
||||
return;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unexpected 'numpy' type kind");
|
||||
}
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Type and attribute detail
|
||||
//----------------------------------------------------------------------------//
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Numpy {
|
||||
namespace detail {
|
||||
|
||||
struct NdArrayTypeStorage : public TypeStorage {
|
||||
using KeyTy = Type;
|
||||
NdArrayTypeStorage(Type optionalDtype) : optionalDtype(optionalDtype) {}
|
||||
bool operator==(const KeyTy &other) const { return optionalDtype == other; }
|
||||
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key);
|
||||
}
|
||||
static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<NdArrayTypeStorage>())
|
||||
NdArrayTypeStorage(key);
|
||||
}
|
||||
|
||||
Type optionalDtype;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
NdArrayType NdArrayType::get(Type optionalDtype, MLIRContext *context) {
|
||||
return Base::get(context, NumpyTypes::NdArray, optionalDtype);
|
||||
}
|
||||
|
||||
Type NdArrayType::getOptionalDtype() { return getImpl()->optionalDtype; }
|
||||
|
|
Loading…
Reference in New Issue