mirror of https://github.com/llvm/torch-mlir
Add NdArrayType.
parent
bccfd5f6fc
commit
efe8915901
|
@ -17,10 +17,18 @@ namespace NPCOMP {
|
||||||
namespace Numpy {
|
namespace Numpy {
|
||||||
|
|
||||||
namespace NumpyTypes {
|
namespace NumpyTypes {
|
||||||
enum Kind { AnyDtypeType = TypeRanges::Numpy, LAST_NUMPY_TYPE = AnyDtypeType };
|
enum Kind {
|
||||||
|
AnyDtypeType = TypeRanges::Numpy,
|
||||||
|
NdArray,
|
||||||
|
LAST_NUMPY_TYPE = AnyDtypeType,
|
||||||
|
};
|
||||||
} // namespace NumpyTypes
|
} // namespace NumpyTypes
|
||||||
|
|
||||||
// The singleton type representing an unknown dtype.
|
namespace detail {
|
||||||
|
struct NdArrayTypeStorage;
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// The singleton type representing an unknown dtype.
|
||||||
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
|
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
@ -34,6 +42,15 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class NdArrayType
|
||||||
|
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage> {
|
||||||
|
public:
|
||||||
|
using Base::Base;
|
||||||
|
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
|
||||||
|
static NdArrayType get(Type optionalDtype, MLIRContext *context);
|
||||||
|
Type getOptionalDtype();
|
||||||
|
};
|
||||||
|
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
|
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
|
||||||
|
|
||||||
} // namespace Numpy
|
} // namespace Numpy
|
||||||
|
|
|
@ -32,7 +32,7 @@ def Numpy_Dialect : Dialect {
|
||||||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<Numpy_Dialect, mnemonic, traits> {
|
Op<Numpy_Dialect, mnemonic, traits> {
|
||||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||||
let printer = [{ return print$cppClass(p, *this); }];
|
let printer = [{ return print$cppClass(p, *this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -47,6 +47,43 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Numpy_NdArrayType : DialectType<Numpy_Dialect,
|
||||||
|
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::NdArrayType>()">, "ndarray type">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::NPCOMP::Numpy::NdArrayType>()"> {
|
||||||
|
let typeDescription = [{
|
||||||
|
NdArrayType: Models a numpy.ndarray and compatible types.
|
||||||
|
Unlike lower level representations, this type solely exists to represent
|
||||||
|
top-level semantics and source-dialect transformations. As such, it
|
||||||
|
is not a general modeling like `tensor` or `memref`, instead being just
|
||||||
|
enough to infer proper lowerings to those types.
|
||||||
|
|
||||||
|
Like its numpy counterparts, NdArrayType represents a mutable array of
|
||||||
|
some value type (dtype), with a shape, strides, and various controls
|
||||||
|
around contiguity. Most of that is not modeled in this type, which focuses
|
||||||
|
on a representation sufficient to infer high level types and aliasing
|
||||||
|
based on program flow.
|
||||||
|
|
||||||
|
Note that most operations in numpy can be legally defined similar to the
|
||||||
|
following:
|
||||||
|
%0 = ... -> !numpy.ndarray<...>
|
||||||
|
%1 = numpy.copy_to_tensor %0 -> tensor<...>
|
||||||
|
%2 = numpy.some_operation %1
|
||||||
|
%4 = numpy.copy_from_tensor -> !numpy.ndarray<...>
|
||||||
|
|
||||||
|
(in other words, the operation does not alias any of its operands to its
|
||||||
|
results)
|
||||||
|
|
||||||
|
When this is the case, the operation will *only* be defined for tensors,
|
||||||
|
as staying in the value domain makes sense for as many operations as
|
||||||
|
can be reasonably represented as such. It is left to subsequent parts of
|
||||||
|
the compiler to transform the program in such a way as to elide the copies
|
||||||
|
that such sequences encode.
|
||||||
|
|
||||||
|
Only ops that mutate or alias their operands should accept and/or produce
|
||||||
|
ndarray types.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type predicates
|
// Type predicates
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -57,7 +94,7 @@ def Numpy_AnyArray : TensorOf<[AnyType]>;
|
||||||
def Numpy_SliceTupleElement : AnyTypeOf<[
|
def Numpy_SliceTupleElement : AnyTypeOf<[
|
||||||
// Supports both "Index Arrays" and "Boolean mask index arrays".
|
// Supports both "Index Arrays" and "Boolean mask index arrays".
|
||||||
Numpy_AnyArray,
|
Numpy_AnyArray,
|
||||||
|
|
||||||
// Indicates that an axis should be added (np.newaxis == None).
|
// Indicates that an axis should be added (np.newaxis == None).
|
||||||
Basicpy_NoneType,
|
Basicpy_NoneType,
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.cpp.inc"
|
#include "npcomp/Dialect/Numpy/IR/NumpyOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addTypes<AnyDtypeType>();
|
addTypes<AnyDtypeType, NdArrayType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
@ -29,6 +29,22 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
|
||||||
if (keyword == "any_dtype")
|
if (keyword == "any_dtype")
|
||||||
return AnyDtypeType::get(getContext());
|
return AnyDtypeType::get(getContext());
|
||||||
|
if (keyword == "ndarray") {
|
||||||
|
// Parse:
|
||||||
|
// ndarray<?>
|
||||||
|
// ndarray<i32>
|
||||||
|
Type dtype;
|
||||||
|
if (parser.parseLess())
|
||||||
|
return Type();
|
||||||
|
if (failed(parser.parseOptionalQuestion())) {
|
||||||
|
// Specified dtype.
|
||||||
|
if (parser.parseType(dtype))
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
if (parser.parseGreater())
|
||||||
|
return Type();
|
||||||
|
return NdArrayType::get(dtype, getContext());
|
||||||
|
}
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||||
return Type();
|
return Type();
|
||||||
|
@ -39,7 +55,53 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||||
case NumpyTypes::AnyDtypeType:
|
case NumpyTypes::AnyDtypeType:
|
||||||
os << "any_dtype";
|
os << "any_dtype";
|
||||||
return;
|
return;
|
||||||
|
case NumpyTypes::NdArray: {
|
||||||
|
auto ndarray = type.cast<NdArrayType>();
|
||||||
|
auto dtype = ndarray.getOptionalDtype();
|
||||||
|
os << "ndarray<";
|
||||||
|
if (dtype)
|
||||||
|
os.printType(dtype);
|
||||||
|
else
|
||||||
|
os << "?";
|
||||||
|
os << ">";
|
||||||
|
return;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("unexpected 'numpy' type kind");
|
llvm_unreachable("unexpected 'numpy' type kind");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
// Type and attribute detail
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace Numpy {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct NdArrayTypeStorage : public TypeStorage {
|
||||||
|
using KeyTy = Type;
|
||||||
|
NdArrayTypeStorage(Type optionalDtype) : optionalDtype(optionalDtype) {}
|
||||||
|
bool operator==(const KeyTy &other) const { return optionalDtype == other; }
|
||||||
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||||
|
return llvm::hash_combine(key);
|
||||||
|
}
|
||||||
|
static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||||
|
const KeyTy &key) {
|
||||||
|
return new (allocator.allocate<NdArrayTypeStorage>())
|
||||||
|
NdArrayTypeStorage(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
Type optionalDtype;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace Numpy
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
NdArrayType NdArrayType::get(Type optionalDtype, MLIRContext *context) {
|
||||||
|
return Base::get(context, NumpyTypes::NdArray, optionalDtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
Type NdArrayType::getOptionalDtype() { return getImpl()->optionalDtype; }
|
||||||
|
|
Loading…
Reference in New Issue