mirror of https://github.com/llvm/torch-mlir
Add basicpy.numeric_constant op.
* Going through TODOs on the PyTorch side, this is a big cause of them (not being able to have constants for signed/unsigned). * Added complex while in here since we're at the phase where it is better to just have things complete than partially done.pull/128/head
parent
bea0af419d
commit
3937dd14cb
|
@ -17,7 +17,8 @@ using namespace torch_mlir;
|
|||
static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
|
||||
MlirAttribute value) {
|
||||
OperationStateHolder s("std.constant", loc);
|
||||
MlirNamedAttribute valueAttr = mlirNamedAttributeGet(toMlirStringRef("value"), value);
|
||||
MlirNamedAttribute valueAttr =
|
||||
mlirNamedAttributeGet(toMlirStringRef("value"), value);
|
||||
mlirOperationStateAddResults(s, 1, &type);
|
||||
mlirOperationStateAddAttributes(s, 1, &valueAttr);
|
||||
return s.createOperation();
|
||||
|
@ -44,12 +45,15 @@ void KernelCallBuilder::addSchemaAttrs() {
|
|||
// sigIsVarret
|
||||
// sigIsMutable
|
||||
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("sigIsMutable"), mlirBoolAttrGet(context, schema.is_mutable())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("sigIsVararg"), mlirBoolAttrGet(context, schema.is_vararg())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("sigIsVarret"), mlirBoolAttrGet(context, schema.is_varret())));
|
||||
attrs.push_back(
|
||||
mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"),
|
||||
mlirBoolAttrGet(context, schema.is_mutable())));
|
||||
attrs.push_back(
|
||||
mlirNamedAttributeGet(toMlirStringRef("sigIsVararg"),
|
||||
mlirBoolAttrGet(context, schema.is_vararg())));
|
||||
attrs.push_back(
|
||||
mlirNamedAttributeGet(toMlirStringRef("sigIsVarret"),
|
||||
mlirBoolAttrGet(context, schema.is_varret())));
|
||||
|
||||
// Arg types.
|
||||
llvm::SmallVector<MlirAttribute, 4> args;
|
||||
|
@ -58,7 +62,8 @@ void KernelCallBuilder::addSchemaAttrs() {
|
|||
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||
}
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("sigArgTypes"), mlirArrayAttrGet(context, args.size(), args.data())));
|
||||
toMlirStringRef("sigArgTypes"),
|
||||
mlirArrayAttrGet(context, args.size(), args.data())));
|
||||
|
||||
// Return types.
|
||||
llvm::SmallVector<MlirAttribute, 4> returns;
|
||||
|
@ -203,14 +208,17 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
|||
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
|
||||
// (this is fragile and reveals details that are not guaranteed).
|
||||
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
||||
funcAttrs.push_back(
|
||||
mlirNamedAttributeGet(toMlirStringRef("type"),
|
||||
mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context, inputTypes.size(), inputTypes.data(),
|
||||
/*numResults=*/0, /*results=*/nullptr))));
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("type"), mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context, inputTypes.size(), inputTypes.data(),
|
||||
/*numResults=*/0, /*results=*/nullptr))));
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
toMlirStringRef("sym_name"), mlirStringAttrGet(context, name.size(), name.data())));
|
||||
toMlirStringRef("sym_name"),
|
||||
mlirStringAttrGet(context, name.size(), name.data())));
|
||||
|
||||
MlirOperationState state = mlirOperationStateGet(toMlirStringRef("func"), location);
|
||||
MlirOperationState state =
|
||||
mlirOperationStateGet(toMlirStringRef("func"), location);
|
||||
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
|
||||
{
|
||||
// Don't access these once ownership transferred.
|
||||
|
@ -234,7 +242,8 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
|||
void FuncBuilder::rewriteFuncReturnTypes(
|
||||
llvm::SmallVectorImpl<MlirType> &resultTypes) {
|
||||
// Get inputs from current function type.
|
||||
MlirAttribute funcTypeAttr = mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
|
||||
MlirAttribute funcTypeAttr =
|
||||
mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
|
||||
assert(!mlirAttributeIsNull(funcTypeAttr) &&
|
||||
"function missing 'type' attribute");
|
||||
assert(mlirAttributeIsAType(funcTypeAttr) &&
|
||||
|
@ -250,7 +259,8 @@ void FuncBuilder::rewriteFuncReturnTypes(
|
|||
mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
||||
resultTypes.size(), resultTypes.data());
|
||||
MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
|
||||
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"), newFuncTypeAttr);
|
||||
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
|
||||
newFuncTypeAttr);
|
||||
(void)newFuncTypeAttr;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,8 +24,7 @@ namespace torch_mlir {
|
|||
class OperationStateHolder {
|
||||
public:
|
||||
OperationStateHolder(const char *name, MlirLocation loc)
|
||||
: state(
|
||||
mlirOperationStateGet(toMlirStringRef(name), loc)) {}
|
||||
: state(mlirOperationStateGet(toMlirStringRef(name), loc)) {}
|
||||
OperationStateHolder(const OperationStateHolder &) = delete;
|
||||
OperationStateHolder(OperationStateHolder &&other) = delete;
|
||||
~OperationStateHolder() {
|
||||
|
|
|
@ -22,6 +22,7 @@ def Basicpy_Dialect : Dialect {
|
|||
Core types and ops
|
||||
}];
|
||||
let cppNamespace = "::mlir::NPCOMP::Basicpy";
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -30,8 +31,9 @@ def Basicpy_Dialect : Dialect {
|
|||
|
||||
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Basicpy_Dialect, mnemonic, traits> {
|
||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -138,7 +140,7 @@ def Basicpy_DictType : DialectType<Basicpy_Dialect,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
// Type/attribute predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_SingletonType : AnyTypeOf<[
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/FunctionSupport.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
|
|
|
@ -13,6 +13,7 @@ include "BasicpyDialect.td"
|
|||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -93,8 +94,54 @@ def CompareOperationAttr : StrEnumAttr<
|
|||
// Constant and constructor operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
|
||||
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
let summary = "A constant from the Python3 numeric type hierarchy";
|
||||
let description = [{
|
||||
Basicpy re-uses core MLIR types to represent the Python3 numeric type
|
||||
hierarchy with the following mappings:
|
||||
|
||||
* Python3 `int` : In python, this type is signed, arbitrary precision but
|
||||
in typical realizations, it maps to an MLIR `IntegerType` of a fixed
|
||||
bit-width (typically si64 if no further information is known). In the
|
||||
future, there may be a real `Basicpy::IntType` that retains the true
|
||||
arbitrary precision nature, but this is deemed an enhancement that
|
||||
does not obviate the need to infer physical, sized types for many
|
||||
real-world cases. As such, the Basicpy numeric type hierarchy will
|
||||
always include physical `IntegerType`, if only to enable progressive
|
||||
lowering and interop with cases where the precise type is known.
|
||||
* Python3 `float` : This is allowed to map to any legal floating point
|
||||
type on the physical machine and is usually represented as a double (f64).
|
||||
In MLIR, any `FloatType` is allowed, which facilitates progressive
|
||||
lowering and interop with cases where a more precise type is known.
|
||||
* Python3 `complex` : Maps to an MLIR `ComplexType` with a `FloatType`
|
||||
elementType (note: in Python, complex numbers are always defined with
|
||||
floating point components).
|
||||
* `bool` : See `bool_constant` for a constant (i1) -> !basicpy.BoolType
|
||||
constant. This constant op is not used for representing such bool
|
||||
values, even though from the Python perspective, bool is part of the
|
||||
numeric hierarchy (the distinction is really only necessary during
|
||||
promotion).
|
||||
|
||||
### Integer Signedness
|
||||
|
||||
All `int` values in Python are signed. However, there exist special cases
|
||||
where libraries (i.e. struct packing and numpy arrays) interoperate with
|
||||
unsigned values. As such, when mapping to MLIR, Python integer types
|
||||
are represented as either signed or unsigned `IntegerType` types and can
|
||||
be lowered to signless integers as appropriate (typically during realization
|
||||
of arithmetic expressions where the choice is meaningful). Since it is not
|
||||
known at the outset when in lowering this information is safe to discard
|
||||
this `numeric_constant` op accepts any signedness.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyAttr:$value);
|
||||
let results = (outs AnyType);
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
let summary = "A boolean constant";
|
||||
let description = [{
|
||||
A constant of type !basicpy.BoolType that can take either an i1 value
|
||||
|
@ -173,7 +220,7 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
|
|||
}
|
||||
|
||||
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
let summary = "Constant bytes value";
|
||||
let description = [{
|
||||
A bytes value of BytesType. The value is represented by a StringAttr.
|
||||
|
@ -204,7 +251,7 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
|
|||
}
|
||||
|
||||
def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
let summary = "Constant string value";
|
||||
let description = [{
|
||||
A string value of StrType. The value is represented by a StringAttr
|
||||
|
@ -224,7 +271,7 @@ def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
|
|||
// Casting and coercion operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Basicpy_AsPredicateValueOp : Basicpy_Op<"as_predicate_value",
|
||||
def Basicpy_AsI1Op : Basicpy_Op<"as_i1",
|
||||
[NoSideEffect]> {
|
||||
let summary = "Evaluates an input to an i1 predicate value";
|
||||
let description = [{
|
||||
|
@ -355,7 +402,6 @@ def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> {
|
|||
StrArrayAttr:$arg_names);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
|
||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins)>,
|
||||
|
@ -427,8 +473,6 @@ def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [
|
|||
let arguments = (ins);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins)>,
|
||||
|
|
|
@ -215,11 +215,11 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Converts the as_predicate_value op for numeric types.
|
||||
class NumericToPredicateValue : public OpRewritePattern<Basicpy::AsPredicateValueOp> {
|
||||
// Converts the as_i1 op for numeric types.
|
||||
class NumericToI1 : public OpRewritePattern<Basicpy::AsI1Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Basicpy::AsPredicateValueOp op,
|
||||
LogicalResult matchAndRewrite(Basicpy::AsI1Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto operandType = op.operand().getType();
|
||||
|
@ -245,5 +245,5 @@ void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns(
|
|||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<NumericBinaryExpr>(context);
|
||||
patterns.insert<NumericCompare>(context);
|
||||
patterns.insert<NumericToPredicateValue>(context);
|
||||
patterns.insert<NumericToI1>(context);
|
||||
}
|
||||
|
|
|
@ -27,6 +27,37 @@ void BasicpyDialect::initialize() {
|
|||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
Operation *BasicpyDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
// NumericConstantOp.
|
||||
// Supports IntegerType (any signedness), FloatType and ComplexType.
|
||||
if (type.isa<IntegerType>() || type.isa<FloatType>() ||
|
||||
type.isa<ComplexType>())
|
||||
return builder.create<NumericConstantOp>(loc, type, value);
|
||||
|
||||
// Bool (i1 -> !basicpy.BoolType).
|
||||
if (type.isa<Basicpy::BoolType>()) {
|
||||
auto i1Value = value.dyn_cast<IntegerAttr>();
|
||||
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
|
||||
return builder.create<BoolConstantOp>(loc, type, i1Value);
|
||||
}
|
||||
|
||||
// Bytes.
|
||||
if (type.isa<Basicpy::BytesType>()) {
|
||||
if (auto strValue = value.dyn_cast<StringAttr>())
|
||||
return builder.create<BytesConstantOp>(loc, type, strValue);
|
||||
}
|
||||
|
||||
// Str.
|
||||
if (type.isa<Basicpy::StrType>()) {
|
||||
if (auto strValue = value.dyn_cast<StringAttr>())
|
||||
return builder.create<StrConstantOp>(loc, type, strValue);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
|
|
|
@ -13,12 +13,13 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Basicpy {
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::Basicpy;
|
||||
|
||||
// Fallback verifier for ops that don't have a dedicated one.
|
||||
template <typename T> static LogicalResult verify(T op) { return success(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BoolConstantOp
|
||||
|
@ -28,6 +29,11 @@ OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
|
|||
return valueAttr();
|
||||
}
|
||||
|
||||
void BoolConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "bool");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BytesConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -36,6 +42,110 @@ OpFoldResult BytesConstantOp::fold(ArrayRef<Attribute> operands) {
|
|||
return valueAttr();
|
||||
}
|
||||
|
||||
void BytesConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "bytes");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NumericConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseNumericConstantOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
Attribute valueAttr;
|
||||
if (parser.parseOptionalAttrDict(result->attributes) ||
|
||||
parser.parseAttribute(valueAttr, "value", result->attributes))
|
||||
return failure();
|
||||
|
||||
// If not an Integer or Float attr (which carry the type in the attr),
|
||||
// expect a trailing type.
|
||||
Type type;
|
||||
if (valueAttr.isa<IntegerAttr>() || valueAttr.isa<FloatAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser.parseColonType(type))
|
||||
return failure();
|
||||
return parser.addTypeToList(type, result->types);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, NumericConstantOp op) {
|
||||
p << "basicpy.numeric_constant ";
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
|
||||
|
||||
if (op.getAttrs().size() > 1)
|
||||
p << ' ';
|
||||
p << op.value();
|
||||
|
||||
// If not an Integer or Float attr, expect a trailing type.
|
||||
if (!op.value().isa<IntegerAttr>() && !op.value().isa<FloatAttr>())
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(NumericConstantOp &op) {
|
||||
auto value = op.value();
|
||||
if (!value)
|
||||
return op.emitOpError("requires a 'value' attribute");
|
||||
auto type = op.getType();
|
||||
|
||||
if (type.isa<FloatType>()) {
|
||||
if (!value.isa<FloatAttr>())
|
||||
return op.emitOpError("requires 'value' to be a floating point constant");
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||
if (!value.isa<IntegerAttr>())
|
||||
return op.emitOpError("requires 'value' to be an integer constant");
|
||||
if (intType.getWidth() == 1)
|
||||
return op.emitOpError("cannot have an i1 type");
|
||||
return success();
|
||||
}
|
||||
|
||||
if (type.isa<ComplexType>()) {
|
||||
if (auto complexComps = value.dyn_cast<ArrayAttr>()) {
|
||||
if (complexComps.size() == 2) {
|
||||
auto realValue = complexComps[0].dyn_cast<FloatAttr>();
|
||||
auto imagValue = complexComps[1].dyn_cast<FloatAttr>();
|
||||
if (realValue && imagValue &&
|
||||
realValue.getType() == imagValue.getType())
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return op.emitOpError("requires 'value' to be a two element array of "
|
||||
"floating point complex number components");
|
||||
}
|
||||
|
||||
return op.emitOpError("unsupported basicpy.numeric_constant type");
|
||||
}
|
||||
|
||||
OpFoldResult NumericConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "numeric_constant has no operands");
|
||||
return value();
|
||||
}
|
||||
|
||||
void NumericConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
Type type = getType();
|
||||
if (auto intCst = value().dyn_cast<IntegerAttr>()) {
|
||||
IntegerType intTy = type.dyn_cast<IntegerType>();
|
||||
APInt intValue = intCst.getValue();
|
||||
|
||||
// Otherwise, build a complex name with the value and type.
|
||||
SmallString<32> specialNameBuffer;
|
||||
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
||||
specialName << "num";
|
||||
if (intTy.isSigned())
|
||||
specialName << intValue.getSExtValue();
|
||||
else
|
||||
specialName << intValue.getZExtValue();
|
||||
if (intTy)
|
||||
specialName << '_' << type;
|
||||
setNameFn(getResult(), specialName.str());
|
||||
} else {
|
||||
setNameFn(getResult(), "num");
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExecOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -54,7 +164,7 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static void printExecOp(OpAsmPrinter &p, ExecOp op) {
|
||||
static void print(OpAsmPrinter &p, ExecOp op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||
p.printRegion(op.body());
|
||||
|
@ -64,7 +174,7 @@ static void printExecOp(OpAsmPrinter &p, ExecOp op) {
|
|||
// FuncTemplateCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) {
|
||||
static LogicalResult verify(FuncTemplateCallOp op) {
|
||||
auto argNames = op.arg_names();
|
||||
if (argNames.size() > op.args().size()) {
|
||||
return op.emitOpError() << "expected <= kw arg names vs args";
|
||||
|
@ -108,7 +218,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
|
||||
static void print(OpAsmPrinter &p, FuncTemplateOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
p.printSymbolName(op.getName());
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
|
@ -116,7 +226,7 @@ static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
|
|||
p.printRegion(op.body());
|
||||
}
|
||||
|
||||
static LogicalResult verifyBasicpyOp(FuncTemplateOp op) {
|
||||
static LogicalResult verify(FuncTemplateOp op) {
|
||||
Block *body = op.getBody();
|
||||
for (auto &childOp : body->getOperations()) {
|
||||
if (!llvm::isa<FuncOp>(childOp) &&
|
||||
|
@ -151,7 +261,7 @@ static ParseResult parseSlotObjectMakeOp(OpAsmParser &parser,
|
|||
parser.getNameLoc(), result->operands);
|
||||
}
|
||||
|
||||
static void printSlotObjectMakeOp(OpAsmPrinter &p, SlotObjectMakeOp op) {
|
||||
static void print(OpAsmPrinter &p, SlotObjectMakeOp op) {
|
||||
// If the argument types do not match the result type slots, then
|
||||
// print the generic form.
|
||||
auto canCustomPrint = ([&]() -> bool {
|
||||
|
@ -218,7 +328,7 @@ static ParseResult parseSlotObjectGetOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static void printSlotObjectGetOp(OpAsmPrinter &p, SlotObjectGetOp op) {
|
||||
static void print(OpAsmPrinter &p, SlotObjectGetOp op) {
|
||||
// If the argument types do not match the result type slots, then
|
||||
// print the generic form.
|
||||
auto canCustomPrint = ([&]() -> bool {
|
||||
|
@ -262,6 +372,11 @@ OpFoldResult StrConstantOp::fold(ArrayRef<Attribute> operands) {
|
|||
return valueAttr();
|
||||
}
|
||||
|
||||
void StrConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "str");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnknownCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -287,9 +402,5 @@ void UnknownCastOp::getCanonicalizationPatterns(
|
|||
patterns.insert<ElideIdentityUnknownCast>(context);
|
||||
}
|
||||
|
||||
} // namespace Basicpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"
|
||||
|
|
|
@ -342,7 +342,7 @@ public:
|
|||
op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) {
|
||||
if (auto op = dyn_cast<AsI1Op>(childOp)) {
|
||||
// Note that the result is always i1 and not subject to type
|
||||
// inference.
|
||||
equations.getTypeNode(op.operand());
|
||||
|
|
|
@ -140,7 +140,7 @@ public:
|
|||
// addSubtypeConstraint(op.false_value(), op.true_value(), op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) {
|
||||
if (auto op = dyn_cast<AsI1Op>(childOp)) {
|
||||
// Note that the result is always i1 and not subject to type
|
||||
// inference.
|
||||
resolveValueType(op.operand());
|
||||
|
|
|
@ -255,7 +255,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
next_value = self.sub_evaluate(next_node)
|
||||
if not next_nodes:
|
||||
return next_value
|
||||
condition_value = ir_h.basicpy_as_predicate_value_op(next_value).result
|
||||
condition_value = ir_h.basicpy_as_i1_op(next_value).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||
condition_value, True)
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
|
@ -347,8 +347,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
|
||||
def visit_IfExp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
test_result = ir_h.basicpy_as_predicate_value_op(self.sub_evaluate(
|
||||
ast_node.test)).result
|
||||
test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||
test_result, True)
|
||||
|
||||
|
@ -386,7 +385,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
operand_value = self.sub_evaluate(ast_node.operand)
|
||||
if isinstance(op, ast.Not):
|
||||
# Special handling for logical-not.
|
||||
condition_value = ir_h.basicpy_as_predicate_value_op(operand_value).result
|
||||
condition_value = ir_h.basicpy_as_i1_op(operand_value).result
|
||||
true_value = ir_h.basicpy_bool_constant_op(True).result
|
||||
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||
self.value = ir_h.select_op(condition_value, false_value,
|
||||
|
|
|
@ -90,8 +90,8 @@ class DialectHelper(_BaseDialectHelper):
|
|||
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
||||
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
||||
|
||||
def basicpy_as_predicate_value_op(self, value):
|
||||
return self.op("basicpy.as_predicate_value", [self.i1_type], [value])
|
||||
def basicpy_as_i1_op(self, value):
|
||||
return self.op("basicpy.as_i1", [self.i1_type], [value])
|
||||
|
||||
def basicpy_unknown_cast_op(self, result_type, operand):
|
||||
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
||||
|
|
|
@ -7,9 +7,66 @@ func @unknown_cast_elide(%arg0 : i32) -> i32 {
|
|||
return %0 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @unknown_cast_preserve
|
||||
func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType {
|
||||
// CHECK: basicpy.unknown_cast
|
||||
%0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @numeric_constant_si32
|
||||
func @numeric_constant_si32() -> si32 {
|
||||
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
|
||||
%0 = basicpy.numeric_constant -1 : si32
|
||||
return %0 : si32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @numeric_constant_ui32
|
||||
func @numeric_constant_ui32() -> ui32 {
|
||||
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
|
||||
%0 = basicpy.numeric_constant 1 : ui32
|
||||
return %0 : ui32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @numeric_constant_f32
|
||||
func @numeric_constant_f32() -> f32 {
|
||||
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
|
||||
%0 = basicpy.numeric_constant 2.0 : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @numeric_constant_complex_f32
|
||||
func @numeric_constant_complex_f32() -> complex<f32> {
|
||||
// CHECK: %num = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
||||
%0 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @bool_constant
|
||||
func @bool_constant() -> !basicpy.BoolType {
|
||||
// CHECK: %bool = basicpy.bool_constant true
|
||||
%0 = basicpy.bool_constant true
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @bytes_constant
|
||||
func @bytes_constant() -> !basicpy.BytesType {
|
||||
// CHECK: %bytes = basicpy.bytes_constant "foobar"
|
||||
%0 = basicpy.bytes_constant "foobar"
|
||||
return %0 : !basicpy.BytesType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @str_constant
|
||||
func @str_constant() -> !basicpy.StrType {
|
||||
// CHECK: %str = basicpy.str_constant "foobar"
|
||||
%0 = basicpy.str_constant "foobar"
|
||||
return %0 : !basicpy.StrType
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s
|
||||
|
||||
func @numeric_constant_string_attr() {
|
||||
// expected-error @+1 {{op requires 'value' to be an integer constant}}
|
||||
%0 = "basicpy.numeric_constant"() {value="somestring" : i32} : () -> (i32)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_bool() {
|
||||
// expected-error @+1 {{cannot have an i1 type}}
|
||||
%0 = "basicpy.numeric_constant"() {value = true} : () -> (i1)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_mismatch_int() {
|
||||
// expected-error @+1 {{op requires 'value' to be a floating point constant}}
|
||||
%0 = "basicpy.numeric_constant"() {value = 1 : i32} : () -> (f64)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_mismatch_float() {
|
||||
// expected-error @+1 {{op requires 'value' to be an integer constant}}
|
||||
%0 = "basicpy.numeric_constant"() {value = 1.0 : f32} : () -> (i32)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_complex_wrong_arity() {
|
||||
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||
%3 = basicpy.numeric_constant [2.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_complex_mismatch_type_real() {
|
||||
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||
%3 = basicpy.numeric_constant [2.0 : f64, 3.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func @numeric_constant_complex_mismatch_type_imag() {
|
||||
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f16] : complex<f32>
|
||||
return
|
||||
}
|
|
@ -24,4 +24,16 @@ func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType {
|
|||
return %0 : !basicpy.TupleType
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @numeric_constant
|
||||
func @numeric_constant() {
|
||||
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
|
||||
%0 = basicpy.numeric_constant -1 : si32
|
||||
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
|
||||
%1 = basicpy.numeric_constant 1 : ui32
|
||||
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
|
||||
%2 = basicpy.numeric_constant 2.0 : f32
|
||||
// CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
||||
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -14,9 +14,9 @@ def logical_and():
|
|||
x = 1
|
||||
y = 0
|
||||
z = 2
|
||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]]
|
||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
|
||||
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]]
|
||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
|
||||
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
|
||||
# CHECK: scf.yield %[[ZCAST]]
|
||||
|
@ -39,12 +39,12 @@ def logical_or():
|
|||
# CHECK: %[[X:.*]] = constant 0
|
||||
# CHECK: %[[Y:.*]] = constant 1
|
||||
# CHECK: %[[Z:.*]] = constant 2
|
||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]]
|
||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
|
||||
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
|
||||
# CHECK: scf.yield %[[XCAST]]
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]]
|
||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
|
||||
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
|
||||
# CHECK: scf.yield %[[YCAST]]
|
||||
|
@ -68,7 +68,7 @@ def logical_not():
|
|||
x = 1
|
||||
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
|
||||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]]
|
||||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
|
||||
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||
return not x
|
||||
|
||||
|
@ -78,7 +78,7 @@ def logical_not():
|
|||
def conditional():
|
||||
# CHECK: %[[X:.*]] = constant 1
|
||||
x = 1
|
||||
# CHECK: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]]
|
||||
# CHECK: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
|
||||
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[TWO:.*]] = constant 2 : i64
|
||||
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]
|
||||
|
|
Loading…
Reference in New Issue