//===- ATenDialect.cpp ------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Dialect/ATen/IR/ATenDialect.h" #include "mlir/IR/DialectImplementation.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::aten; namespace mlir { namespace NPCOMP { namespace aten { namespace detail { /// This class holds the implementation of the ATenListType. /// It is intended to be uniqued based on its content and owned by the context. struct ATenListTypeStorage : public mlir::TypeStorage { ATenListTypeStorage(Type elementType) : elementType(elementType) {} /// The hash key used for uniquing. using KeyTy = mlir::Type; bool operator==(const KeyTy &key) const { return key == getElementType(); } /// This is a factory method to create our type storage. It is only /// invoked after looking up the type in the context using the key and not /// finding it. static ATenListTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) { // Allocate the instance for the ATenListTypeStorage itself auto *storage = allocator.allocate(); // Initialize the instance using placement new. return new (storage) ATenListTypeStorage(key); } Type getElementType() const { return elementType; } private: Type elementType; }; } // namespace detail ATenListType ATenListType::get(mlir::Type elemType) { return Base::get(elemType.getContext(), elemType); } mlir::Type ATenListType::getElementType() { return getImpl()->getElementType(); } mlir::Type ATenDialect::parseType(DialectAsmParser &parser) const { // All types start with an identifier that we switch on. StringRef typeNameSpelling; if (failed(parser.parseKeyword(&typeNameSpelling))) return nullptr; if (typeNameSpelling == "list") { if (failed(parser.parseLess())) return nullptr; Type t; if (failed(parser.parseType(t))) return nullptr; if (failed(parser.parseGreater())) return nullptr; return ATenListType::get(t); } parser.emitError(parser.getCurrentLocation(), "Invalid ATen type '" + typeNameSpelling + "'"); return nullptr; } /// Print a ATenListType void ATenDialect::printType(mlir::Type type, DialectAsmPrinter &os) const { auto ty = type.dyn_cast(); if (!ty) { os << "unknown aten type"; return; } os << "list<"; os.getStream() << ty.getElementType(); os << ">"; } } // namespace aten } // namespace NPCOMP } // namespace mlir Operation *ATenDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { // Bool (i1 -> !basicpy.BoolType). if (type.isa()) { auto i1Value = value.dyn_cast(); if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1) return builder.create(loc, type, i1Value); } return nullptr; } void ATenDialect::initialize() { addTypes(); addOperations< #define GET_OP_LIST #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc" >(); getContext()->getOrLoadDialect("torch"); getContext()->getOrLoadDialect("std"); } #include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"