mirror of https://github.com/llvm/torch-mlir
154 lines
4.9 KiB
C++
154 lines
4.9 KiB
C++
//===- NumpyOps.cpp - Core numpy dialect ops --------------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
|
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
|
|
|
namespace mlir {
|
|
namespace NPCOMP {
|
|
namespace Numpy {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BuildinUfuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseBuiltinUfuncOp(OpAsmParser &parser,
|
|
OperationState *result) {
|
|
// Parse the name as a symbol.
|
|
StringAttr nameAttr;
|
|
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
|
result->attributes)) {
|
|
return failure();
|
|
}
|
|
if (failed(parser.parseOptionalAttrDict(result->attributes))) {
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static void printBuiltinUfuncOp(OpAsmPrinter &p, BuiltinUfuncOp op) {
|
|
p << op.getOperationName() << " ";
|
|
p.printSymbolName(op.getName());
|
|
p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName()});
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GenericUfuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseGenericUfuncOp(OpAsmParser &parser,
|
|
OperationState *result) {
|
|
Builder b(result->getContext());
|
|
|
|
// Parse the name as a symbol.
|
|
StringAttr nameAttr;
|
|
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
|
result->attributes))
|
|
return failure();
|
|
|
|
// Parse the body of overloads.
|
|
if (parser.parseLParen())
|
|
return failure();
|
|
|
|
SmallVector<Attribute, 4> overloadTypes;
|
|
for (bool first = true;; first = false) {
|
|
if (first) {
|
|
if (parser.parseOptionalKeyword("overload"))
|
|
break;
|
|
}
|
|
if (!first) {
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
if (parser.parseKeyword("overload"))
|
|
return failure();
|
|
}
|
|
SmallVector<OpAsmParser::OperandType, 2> argNames;
|
|
SmallVector<Type, 2> argTypes;
|
|
SmallVector<Type, 1> resultTypes;
|
|
SmallVector<NamedAttrList, 1> unusedAttrs;
|
|
bool isVariadic = false;
|
|
if (::mlir::impl::parseFunctionSignature(parser, false, argNames, argTypes,
|
|
unusedAttrs, isVariadic,
|
|
resultTypes, unusedAttrs))
|
|
return failure();
|
|
overloadTypes.push_back(TypeAttr::get(
|
|
FunctionType::get(argTypes, resultTypes, result->getContext())));
|
|
auto *region = result->addRegion();
|
|
if (parser.parseRegion(*region, argNames, argTypes))
|
|
return failure();
|
|
}
|
|
|
|
if (parser.parseRParen())
|
|
return failure();
|
|
|
|
// Parse 'attributes {...}'
|
|
if (parser.parseOptionalAttrDictWithKeyword(result->attributes))
|
|
return failure();
|
|
result->addAttribute(b.getIdentifier("overload_types"),
|
|
b.getArrayAttr(overloadTypes));
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
|
|
p << op.getOperationName() << " @" << op.getName() << "(";
|
|
bool first = true;
|
|
for (auto it : llvm::enumerate(op.getRegions())) {
|
|
auto *region = it.value();
|
|
if (first)
|
|
first = false;
|
|
else
|
|
p << ", ";
|
|
if (region->empty()) {
|
|
p << "<<OVERLOAD_ERROR>>";
|
|
continue;
|
|
}
|
|
|
|
Block &entryBlock = region->front();
|
|
p << "overload(";
|
|
if (it.index() >= op.overload_types().size()) {
|
|
p << "<<MISSING OVERLOAD TYPE>>";
|
|
continue;
|
|
}
|
|
TypeAttr tattr = op.overload_types()[it.index()].cast<TypeAttr>();
|
|
FunctionType overloadType = tattr.getValue().dyn_cast<FunctionType>();
|
|
if (!overloadType) {
|
|
p << "<<ILLEGAL OVERLOAD TYPE>>";
|
|
continue;
|
|
}
|
|
if (overloadType.getNumInputs() != entryBlock.getNumArguments()) {
|
|
p << "<<OVERLOAD ARG MISMATCH>>";
|
|
continue;
|
|
}
|
|
|
|
auto argTypes = entryBlock.getArgumentTypes();
|
|
for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
|
|
auto arg = entryBlock.getArgument(i);
|
|
if (i > 0)
|
|
p << ", ";
|
|
p.printOperand(arg);
|
|
p << ": ";
|
|
p.printType(overloadType.getInputs()[i]);
|
|
}
|
|
p << ")";
|
|
p.printArrowTypeList(overloadType.getResults());
|
|
p.printRegion(*region, false, true);
|
|
}
|
|
p << ")";
|
|
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
|
} // namespace Numpy
|
|
} // namespace NPCOMP
|
|
} // namespace mlir
|