Add basicpy.numeric_constant op.

* Going through TODOs on the PyTorch side, this is a big cause of them (not being able to have constants for signed/unsigned).
* Added complex while in here since we're at the phase where it is better to just have things complete than partially done.
pull/128/head
Stella Laurenzo 2020-11-23 19:20:26 -08:00
parent bea0af419d
commit 3937dd14cb
16 changed files with 376 additions and 61 deletions

View File

@ -17,7 +17,8 @@ using namespace torch_mlir;
static MlirOperation createStandardConstant(MlirLocation loc, MlirType type, static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
MlirAttribute value) { MlirAttribute value) {
OperationStateHolder s("std.constant", loc); OperationStateHolder s("std.constant", loc);
MlirNamedAttribute valueAttr = mlirNamedAttributeGet(toMlirStringRef("value"), value); MlirNamedAttribute valueAttr =
mlirNamedAttributeGet(toMlirStringRef("value"), value);
mlirOperationStateAddResults(s, 1, &type); mlirOperationStateAddResults(s, 1, &type);
mlirOperationStateAddAttributes(s, 1, &valueAttr); mlirOperationStateAddAttributes(s, 1, &valueAttr);
return s.createOperation(); return s.createOperation();
@ -44,12 +45,15 @@ void KernelCallBuilder::addSchemaAttrs() {
// sigIsVarret // sigIsVarret
// sigIsMutable // sigIsMutable
llvm::SmallVector<MlirNamedAttribute, 8> attrs; llvm::SmallVector<MlirNamedAttribute, 8> attrs;
attrs.push_back(mlirNamedAttributeGet( attrs.push_back(
toMlirStringRef("sigIsMutable"), mlirBoolAttrGet(context, schema.is_mutable()))); mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"),
attrs.push_back(mlirNamedAttributeGet( mlirBoolAttrGet(context, schema.is_mutable())));
toMlirStringRef("sigIsVararg"), mlirBoolAttrGet(context, schema.is_vararg()))); attrs.push_back(
attrs.push_back(mlirNamedAttributeGet( mlirNamedAttributeGet(toMlirStringRef("sigIsVararg"),
toMlirStringRef("sigIsVarret"), mlirBoolAttrGet(context, schema.is_varret()))); mlirBoolAttrGet(context, schema.is_vararg())));
attrs.push_back(
mlirNamedAttributeGet(toMlirStringRef("sigIsVarret"),
mlirBoolAttrGet(context, schema.is_varret())));
// Arg types. // Arg types.
llvm::SmallVector<MlirAttribute, 4> args; llvm::SmallVector<MlirAttribute, 4> args;
@ -58,7 +62,8 @@ void KernelCallBuilder::addSchemaAttrs() {
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data())); args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
} }
attrs.push_back(mlirNamedAttributeGet( attrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("sigArgTypes"), mlirArrayAttrGet(context, args.size(), args.data()))); toMlirStringRef("sigArgTypes"),
mlirArrayAttrGet(context, args.size(), args.data())));
// Return types. // Return types.
llvm::SmallVector<MlirAttribute, 4> returns; llvm::SmallVector<MlirAttribute, 4> returns;
@ -203,14 +208,17 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
// TODO: Create a dedicated API upstream for creating/manipulating func ops. // TODO: Create a dedicated API upstream for creating/manipulating func ops.
// (this is fragile and reveals details that are not guaranteed). // (this is fragile and reveals details that are not guaranteed).
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs; llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
funcAttrs.push_back(
mlirNamedAttributeGet(toMlirStringRef("type"),
mlirTypeAttrGet(mlirFunctionTypeGet(
context, inputTypes.size(), inputTypes.data(),
/*numResults=*/0, /*results=*/nullptr))));
funcAttrs.push_back(mlirNamedAttributeGet( funcAttrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("type"), mlirTypeAttrGet(mlirFunctionTypeGet( toMlirStringRef("sym_name"),
context, inputTypes.size(), inputTypes.data(), mlirStringAttrGet(context, name.size(), name.data())));
/*numResults=*/0, /*results=*/nullptr))));
funcAttrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("sym_name"), mlirStringAttrGet(context, name.size(), name.data())));
MlirOperationState state = mlirOperationStateGet(toMlirStringRef("func"), location); MlirOperationState state =
mlirOperationStateGet(toMlirStringRef("func"), location);
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data()); mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
{ {
// Don't access these once ownership transferred. // Don't access these once ownership transferred.
@ -234,7 +242,8 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
void FuncBuilder::rewriteFuncReturnTypes( void FuncBuilder::rewriteFuncReturnTypes(
llvm::SmallVectorImpl<MlirType> &resultTypes) { llvm::SmallVectorImpl<MlirType> &resultTypes) {
// Get inputs from current function type. // Get inputs from current function type.
MlirAttribute funcTypeAttr = mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type")); MlirAttribute funcTypeAttr =
mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
assert(!mlirAttributeIsNull(funcTypeAttr) && assert(!mlirAttributeIsNull(funcTypeAttr) &&
"function missing 'type' attribute"); "function missing 'type' attribute");
assert(mlirAttributeIsAType(funcTypeAttr) && assert(mlirAttributeIsAType(funcTypeAttr) &&
@ -250,7 +259,8 @@ void FuncBuilder::rewriteFuncReturnTypes(
mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
resultTypes.size(), resultTypes.data()); resultTypes.size(), resultTypes.data());
MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType); MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"), newFuncTypeAttr); mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
newFuncTypeAttr);
(void)newFuncTypeAttr; (void)newFuncTypeAttr;
} }

View File

@ -24,8 +24,7 @@ namespace torch_mlir {
class OperationStateHolder { class OperationStateHolder {
public: public:
OperationStateHolder(const char *name, MlirLocation loc) OperationStateHolder(const char *name, MlirLocation loc)
: state( : state(mlirOperationStateGet(toMlirStringRef(name), loc)) {}
mlirOperationStateGet(toMlirStringRef(name), loc)) {}
OperationStateHolder(const OperationStateHolder &) = delete; OperationStateHolder(const OperationStateHolder &) = delete;
OperationStateHolder(OperationStateHolder &&other) = delete; OperationStateHolder(OperationStateHolder &&other) = delete;
~OperationStateHolder() { ~OperationStateHolder() {

View File

@ -22,6 +22,7 @@ def Basicpy_Dialect : Dialect {
Core types and ops Core types and ops
}]; }];
let cppNamespace = "::mlir::NPCOMP::Basicpy"; let cppNamespace = "::mlir::NPCOMP::Basicpy";
let hasConstantMaterializer = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -30,8 +31,9 @@ def Basicpy_Dialect : Dialect {
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> : class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Basicpy_Dialect, mnemonic, traits> { Op<Basicpy_Dialect, mnemonic, traits> {
let parser = [{ return parse$cppClass(parser, &result); }]; let parser = [{ return ::parse$cppClass(parser, &result); }];
let printer = [{ return print$cppClass(p, *this); }]; let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -138,7 +140,7 @@ def Basicpy_DictType : DialectType<Basicpy_Dialect,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type predicates // Type/attribute predicates
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def Basicpy_SingletonType : AnyTypeOf<[ def Basicpy_SingletonType : AnyTypeOf<[

View File

@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h"

View File

@ -13,6 +13,7 @@ include "BasicpyDialect.td"
include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -93,8 +94,54 @@ def CompareOperationAttr : StrEnumAttr<
// Constant and constructor operations // Constant and constructor operations
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "A constant from the Python3 numeric type hierarchy";
let description = [{
Basicpy re-uses core MLIR types to represent the Python3 numeric type
hierarchy with the following mappings:
* Python3 `int` : In python, this type is signed, arbitrary precision but
in typical realizations, it maps to an MLIR `IntegerType` of a fixed
bit-width (typically si64 if no further information is known). In the
future, there may be a real `Basicpy::IntType` that retains the true
arbitrary precision nature, but this is deemed an enhancement that
does not obviate the need to infer physical, sized types for many
real-world cases. As such, the Basicpy numeric type hierarchy will
always include physical `IntegerType`, if only to enable progressive
lowering and interop with cases where the precise type is known.
* Python3 `float` : This is allowed to map to any legal floating point
type on the physical machine and is usually represented as a double (f64).
In MLIR, any `FloatType` is allowed, which facilitates progressive
lowering and interop with cases where a more precise type is known.
* Python3 `complex` : Maps to an MLIR `ComplexType` with a `FloatType`
elementType (note: in Python, complex numbers are always defined with
floating point components).
* `bool` : See `bool_constant` for a constant (i1) -> !basicpy.BoolType
constant. This constant op is not used for representing such bool
values, even though from the Python perspective, bool is part of the
numeric hierarchy (the distinction is really only necessary during
promotion).
### Integer Signedness
All `int` values in Python are signed. However, there exist special cases
where libraries (i.e. struct packing and numpy arrays) interoperate with
unsigned values. As such, when mapping to MLIR, Python integer types
are represented as either signed or unsigned `IntegerType` types and can
be lowered to signless integers as appropriate (typically during realization
of arithmetic expressions where the choice is meaningful). Since it is not
known at the outset when in lowering this information is safe to discard
this `numeric_constant` op accepts any signedness.
}];
let arguments = (ins AnyAttr:$value);
let results = (outs AnyType);
let hasFolder = 1;
}
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [ def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
ConstantLike, NoSideEffect]> { ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "A boolean constant"; let summary = "A boolean constant";
let description = [{ let description = [{
A constant of type !basicpy.BoolType that can take either an i1 value A constant of type !basicpy.BoolType that can take either an i1 value
@ -173,7 +220,7 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
} }
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [ def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
ConstantLike, NoSideEffect]> { ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "Constant bytes value"; let summary = "Constant bytes value";
let description = [{ let description = [{
A bytes value of BytesType. The value is represented by a StringAttr. A bytes value of BytesType. The value is represented by a StringAttr.
@ -204,7 +251,7 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
} }
def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [ def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
ConstantLike, NoSideEffect]> { ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "Constant string value"; let summary = "Constant string value";
let description = [{ let description = [{
A string value of StrType. The value is represented by a StringAttr A string value of StrType. The value is represented by a StringAttr
@ -224,7 +271,7 @@ def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
// Casting and coercion operations // Casting and coercion operations
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def Basicpy_AsPredicateValueOp : Basicpy_Op<"as_predicate_value", def Basicpy_AsI1Op : Basicpy_Op<"as_i1",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "Evaluates an input to an i1 predicate value"; let summary = "Evaluates an input to an i1 predicate value";
let description = [{ let description = [{
@ -355,7 +402,6 @@ def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> {
StrArrayAttr:$arg_names); StrArrayAttr:$arg_names);
let results = (outs AnyType:$result); let results = (outs AnyType:$result);
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)"; let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
let verifier = [{ return verifyBasicpyOp(*this); }];
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilderDAG<(ins)>, OpBuilderDAG<(ins)>,
@ -427,8 +473,6 @@ def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [
let arguments = (ins); let arguments = (ins);
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let verifier = [{ return verifyBasicpyOp(*this); }];
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilderDAG<(ins)>, OpBuilderDAG<(ins)>,

View File

@ -215,11 +215,11 @@ public:
} }
}; };
// Converts the as_predicate_value op for numeric types. // Converts the as_i1 op for numeric types.
class NumericToPredicateValue : public OpRewritePattern<Basicpy::AsPredicateValueOp> { class NumericToI1 : public OpRewritePattern<Basicpy::AsI1Op> {
public: public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Basicpy::AsPredicateValueOp op, LogicalResult matchAndRewrite(Basicpy::AsI1Op op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto operandType = op.operand().getType(); auto operandType = op.operand().getType();
@ -245,5 +245,5 @@ void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) { MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<NumericBinaryExpr>(context); patterns.insert<NumericBinaryExpr>(context);
patterns.insert<NumericCompare>(context); patterns.insert<NumericCompare>(context);
patterns.insert<NumericToPredicateValue>(context); patterns.insert<NumericToI1>(context);
} }

View File

@ -27,6 +27,37 @@ void BasicpyDialect::initialize() {
allowUnknownOperations(); allowUnknownOperations();
} }
Operation *BasicpyDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
// NumericConstantOp.
// Supports IntegerType (any signedness), FloatType and ComplexType.
if (type.isa<IntegerType>() || type.isa<FloatType>() ||
type.isa<ComplexType>())
return builder.create<NumericConstantOp>(loc, type, value);
// Bool (i1 -> !basicpy.BoolType).
if (type.isa<Basicpy::BoolType>()) {
auto i1Value = value.dyn_cast<IntegerAttr>();
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
return builder.create<BoolConstantOp>(loc, type, i1Value);
}
// Bytes.
if (type.isa<Basicpy::BytesType>()) {
if (auto strValue = value.dyn_cast<StringAttr>())
return builder.create<BytesConstantOp>(loc, type, strValue);
}
// Str.
if (type.isa<Basicpy::StrType>()) {
if (auto strValue = value.dyn_cast<StringAttr>())
return builder.create<StrConstantOp>(loc, type, strValue);
}
return nullptr;
}
Type BasicpyDialect::parseType(DialectAsmParser &parser) const { Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword; StringRef keyword;
if (parser.parseKeyword(&keyword)) if (parser.parseKeyword(&keyword))

View File

@ -13,12 +13,13 @@
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc" #include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc"
namespace mlir { using namespace mlir;
namespace NPCOMP { using namespace mlir::NPCOMP::Basicpy;
namespace Basicpy {
// Fallback verifier for ops that don't have a dedicated one.
template <typename T> static LogicalResult verify(T op) { return success(); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BoolConstantOp // BoolConstantOp
@ -28,6 +29,11 @@ OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
return valueAttr(); return valueAttr();
} }
void BoolConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "bool");
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BytesConstantOp // BytesConstantOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -36,6 +42,110 @@ OpFoldResult BytesConstantOp::fold(ArrayRef<Attribute> operands) {
return valueAttr(); return valueAttr();
} }
void BytesConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "bytes");
}
//===----------------------------------------------------------------------===//
// NumericConstantOp
//===----------------------------------------------------------------------===//
static ParseResult parseNumericConstantOp(OpAsmParser &parser,
OperationState *result) {
Attribute valueAttr;
if (parser.parseOptionalAttrDict(result->attributes) ||
parser.parseAttribute(valueAttr, "value", result->attributes))
return failure();
// If not an Integer or Float attr (which carry the type in the attr),
// expect a trailing type.
Type type;
if (valueAttr.isa<IntegerAttr>() || valueAttr.isa<FloatAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
return parser.addTypeToList(type, result->types);
}
static void print(OpAsmPrinter &p, NumericConstantOp op) {
p << "basicpy.numeric_constant ";
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
if (op.getAttrs().size() > 1)
p << ' ';
p << op.value();
// If not an Integer or Float attr, expect a trailing type.
if (!op.value().isa<IntegerAttr>() && !op.value().isa<FloatAttr>())
p << " : " << op.getType();
}
static LogicalResult verify(NumericConstantOp &op) {
auto value = op.value();
if (!value)
return op.emitOpError("requires a 'value' attribute");
auto type = op.getType();
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return op.emitOpError("requires 'value' to be a floating point constant");
return success();
}
if (auto intType = type.dyn_cast<IntegerType>()) {
if (!value.isa<IntegerAttr>())
return op.emitOpError("requires 'value' to be an integer constant");
if (intType.getWidth() == 1)
return op.emitOpError("cannot have an i1 type");
return success();
}
if (type.isa<ComplexType>()) {
if (auto complexComps = value.dyn_cast<ArrayAttr>()) {
if (complexComps.size() == 2) {
auto realValue = complexComps[0].dyn_cast<FloatAttr>();
auto imagValue = complexComps[1].dyn_cast<FloatAttr>();
if (realValue && imagValue &&
realValue.getType() == imagValue.getType())
return success();
}
}
return op.emitOpError("requires 'value' to be a two element array of "
"floating point complex number components");
}
return op.emitOpError("unsupported basicpy.numeric_constant type");
}
OpFoldResult NumericConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "numeric_constant has no operands");
return value();
}
void NumericConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
Type type = getType();
if (auto intCst = value().dyn_cast<IntegerAttr>()) {
IntegerType intTy = type.dyn_cast<IntegerType>();
APInt intValue = intCst.getValue();
// Otherwise, build a complex name with the value and type.
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << "num";
if (intTy.isSigned())
specialName << intValue.getSExtValue();
else
specialName << intValue.getZExtValue();
if (intTy)
specialName << '_' << type;
setNameFn(getResult(), specialName.str());
} else {
setNameFn(getResult(), "num");
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ExecOp // ExecOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -54,7 +164,7 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
return success(); return success();
} }
static void printExecOp(OpAsmPrinter &p, ExecOp op) { static void print(OpAsmPrinter &p, ExecOp op) {
p << op.getOperationName(); p << op.getOperationName();
p.printOptionalAttrDictWithKeyword(op.getAttrs()); p.printOptionalAttrDictWithKeyword(op.getAttrs());
p.printRegion(op.body()); p.printRegion(op.body());
@ -64,7 +174,7 @@ static void printExecOp(OpAsmPrinter &p, ExecOp op) {
// FuncTemplateCallOp // FuncTemplateCallOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) { static LogicalResult verify(FuncTemplateCallOp op) {
auto argNames = op.arg_names(); auto argNames = op.arg_names();
if (argNames.size() > op.args().size()) { if (argNames.size() > op.args().size()) {
return op.emitOpError() << "expected <= kw arg names vs args"; return op.emitOpError() << "expected <= kw arg names vs args";
@ -108,7 +218,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
return success(); return success();
} }
static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) { static void print(OpAsmPrinter &p, FuncTemplateOp op) {
p << op.getOperationName() << " "; p << op.getOperationName() << " ";
p.printSymbolName(op.getName()); p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(op.getAttrs(), p.printOptionalAttrDictWithKeyword(op.getAttrs(),
@ -116,7 +226,7 @@ static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
p.printRegion(op.body()); p.printRegion(op.body());
} }
static LogicalResult verifyBasicpyOp(FuncTemplateOp op) { static LogicalResult verify(FuncTemplateOp op) {
Block *body = op.getBody(); Block *body = op.getBody();
for (auto &childOp : body->getOperations()) { for (auto &childOp : body->getOperations()) {
if (!llvm::isa<FuncOp>(childOp) && if (!llvm::isa<FuncOp>(childOp) &&
@ -151,7 +261,7 @@ static ParseResult parseSlotObjectMakeOp(OpAsmParser &parser,
parser.getNameLoc(), result->operands); parser.getNameLoc(), result->operands);
} }
static void printSlotObjectMakeOp(OpAsmPrinter &p, SlotObjectMakeOp op) { static void print(OpAsmPrinter &p, SlotObjectMakeOp op) {
// If the argument types do not match the result type slots, then // If the argument types do not match the result type slots, then
// print the generic form. // print the generic form.
auto canCustomPrint = ([&]() -> bool { auto canCustomPrint = ([&]() -> bool {
@ -218,7 +328,7 @@ static ParseResult parseSlotObjectGetOp(OpAsmParser &parser,
return success(); return success();
} }
static void printSlotObjectGetOp(OpAsmPrinter &p, SlotObjectGetOp op) { static void print(OpAsmPrinter &p, SlotObjectGetOp op) {
// If the argument types do not match the result type slots, then // If the argument types do not match the result type slots, then
// print the generic form. // print the generic form.
auto canCustomPrint = ([&]() -> bool { auto canCustomPrint = ([&]() -> bool {
@ -262,6 +372,11 @@ OpFoldResult StrConstantOp::fold(ArrayRef<Attribute> operands) {
return valueAttr(); return valueAttr();
} }
void StrConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "str");
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// UnknownCastOp // UnknownCastOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -287,9 +402,5 @@ void UnknownCastOp::getCanonicalizationPatterns(
patterns.insert<ElideIdentityUnknownCast>(context); patterns.insert<ElideIdentityUnknownCast>(context);
} }
} // namespace Basicpy
} // namespace NPCOMP
} // namespace mlir
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"

View File

@ -342,7 +342,7 @@ public:
op); op);
return WalkResult::advance(); return WalkResult::advance();
} }
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) { if (auto op = dyn_cast<AsI1Op>(childOp)) {
// Note that the result is always i1 and not subject to type // Note that the result is always i1 and not subject to type
// inference. // inference.
equations.getTypeNode(op.operand()); equations.getTypeNode(op.operand());

View File

@ -140,7 +140,7 @@ public:
// addSubtypeConstraint(op.false_value(), op.true_value(), op); // addSubtypeConstraint(op.false_value(), op.true_value(), op);
return WalkResult::advance(); return WalkResult::advance();
} }
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) { if (auto op = dyn_cast<AsI1Op>(childOp)) {
// Note that the result is always i1 and not subject to type // Note that the result is always i1 and not subject to type
// inference. // inference.
resolveValueType(op.operand()); resolveValueType(op.operand());

View File

@ -255,7 +255,7 @@ class ExpressionImporter(BaseNodeVisitor):
next_value = self.sub_evaluate(next_node) next_value = self.sub_evaluate(next_node)
if not next_nodes: if not next_nodes:
return next_value return next_value
condition_value = ir_h.basicpy_as_predicate_value_op(next_value).result condition_value = ir_h.basicpy_as_i1_op(next_value).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType], if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
condition_value, True) condition_value, True)
orig_ip = ir_h.builder.insertion_point orig_ip = ir_h.builder.insertion_point
@ -347,8 +347,7 @@ class ExpressionImporter(BaseNodeVisitor):
def visit_IfExp(self, ast_node): def visit_IfExp(self, ast_node):
ir_h = self.fctx.ir_h ir_h = self.fctx.ir_h
test_result = ir_h.basicpy_as_predicate_value_op(self.sub_evaluate( test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result
ast_node.test)).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType], if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
test_result, True) test_result, True)
@ -386,7 +385,7 @@ class ExpressionImporter(BaseNodeVisitor):
operand_value = self.sub_evaluate(ast_node.operand) operand_value = self.sub_evaluate(ast_node.operand)
if isinstance(op, ast.Not): if isinstance(op, ast.Not):
# Special handling for logical-not. # Special handling for logical-not.
condition_value = ir_h.basicpy_as_predicate_value_op(operand_value).result condition_value = ir_h.basicpy_as_i1_op(operand_value).result
true_value = ir_h.basicpy_bool_constant_op(True).result true_value = ir_h.basicpy_bool_constant_op(True).result
false_value = ir_h.basicpy_bool_constant_op(False).result false_value = ir_h.basicpy_bool_constant_op(False).result
self.value = ir_h.select_op(condition_value, false_value, self.value = ir_h.select_op(condition_value, false_value,

View File

@ -90,8 +90,8 @@ class DialectHelper(_BaseDialectHelper):
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))}) attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs) return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
def basicpy_as_predicate_value_op(self, value): def basicpy_as_i1_op(self, value):
return self.op("basicpy.as_predicate_value", [self.i1_type], [value]) return self.op("basicpy.as_i1", [self.i1_type], [value])
def basicpy_unknown_cast_op(self, result_type, operand): def basicpy_unknown_cast_op(self, result_type, operand):
return self.op("basicpy.unknown_cast", [result_type], [operand]) return self.op("basicpy.unknown_cast", [result_type], [operand])

View File

@ -7,9 +7,66 @@ func @unknown_cast_elide(%arg0 : i32) -> i32 {
return %0 : i32 return %0 : i32
} }
// -----
// CHECK-LABEL: func @unknown_cast_preserve // CHECK-LABEL: func @unknown_cast_preserve
func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType { func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType {
// CHECK: basicpy.unknown_cast // CHECK: basicpy.unknown_cast
%0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType %0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType return %0 : !basicpy.UnknownType
} }
// -----
// CHECK-LABEL: @numeric_constant_si32
func @numeric_constant_si32() -> si32 {
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
%0 = basicpy.numeric_constant -1 : si32
return %0 : si32
}
// -----
// CHECK-LABEL: @numeric_constant_ui32
func @numeric_constant_ui32() -> ui32 {
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
%0 = basicpy.numeric_constant 1 : ui32
return %0 : ui32
}
// -----
// CHECK-LABEL: @numeric_constant_f32
func @numeric_constant_f32() -> f32 {
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
%0 = basicpy.numeric_constant 2.0 : f32
return %0 : f32
}
// -----
// CHECK-LABEL: @numeric_constant_complex_f32
func @numeric_constant_complex_f32() -> complex<f32> {
// CHECK: %num = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
%0 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
return %0 : complex<f32>
}
// -----
// CHECK-LABEL: @bool_constant
func @bool_constant() -> !basicpy.BoolType {
// CHECK: %bool = basicpy.bool_constant true
%0 = basicpy.bool_constant true
return %0 : !basicpy.BoolType
}
// -----
// CHECK-LABEL: @bytes_constant
func @bytes_constant() -> !basicpy.BytesType {
// CHECK: %bytes = basicpy.bytes_constant "foobar"
%0 = basicpy.bytes_constant "foobar"
return %0 : !basicpy.BytesType
}
// -----
// CHECK-LABEL: @str_constant
func @str_constant() -> !basicpy.StrType {
// CHECK: %str = basicpy.str_constant "foobar"
%0 = basicpy.str_constant "foobar"
return %0 : !basicpy.StrType
}

View File

@ -0,0 +1,49 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s
func @numeric_constant_string_attr() {
// expected-error @+1 {{op requires 'value' to be an integer constant}}
%0 = "basicpy.numeric_constant"() {value="somestring" : i32} : () -> (i32)
return
}
// -----
func @numeric_constant_bool() {
// expected-error @+1 {{cannot have an i1 type}}
%0 = "basicpy.numeric_constant"() {value = true} : () -> (i1)
return
}
// -----
func @numeric_constant_mismatch_int() {
// expected-error @+1 {{op requires 'value' to be a floating point constant}}
%0 = "basicpy.numeric_constant"() {value = 1 : i32} : () -> (f64)
return
}
// -----
func @numeric_constant_mismatch_float() {
// expected-error @+1 {{op requires 'value' to be an integer constant}}
%0 = "basicpy.numeric_constant"() {value = 1.0 : f32} : () -> (i32)
return
}
// -----
func @numeric_constant_complex_wrong_arity() {
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
%3 = basicpy.numeric_constant [2.0 : f32] : complex<f32>
return
}
// -----
func @numeric_constant_complex_mismatch_type_real() {
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
%3 = basicpy.numeric_constant [2.0 : f64, 3.0 : f32] : complex<f32>
return
}
// -----
func @numeric_constant_complex_mismatch_type_imag() {
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f16] : complex<f32>
return
}

View File

@ -24,4 +24,16 @@ func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType {
return %0 : !basicpy.TupleType return %0 : !basicpy.TupleType
} }
// -----
// CHECK-LABEL: @numeric_constant
func @numeric_constant() {
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
%0 = basicpy.numeric_constant -1 : si32
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
%1 = basicpy.numeric_constant 1 : ui32
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
%2 = basicpy.numeric_constant 2.0 : f32
// CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
return
}

View File

@ -14,9 +14,9 @@ def logical_and():
x = 1 x = 1
y = 0 y = 0
z = 2 z = 2
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]] # CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]] # CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]] # CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
# CHECK: scf.yield %[[ZCAST]] # CHECK: scf.yield %[[ZCAST]]
@ -39,12 +39,12 @@ def logical_or():
# CHECK: %[[X:.*]] = constant 0 # CHECK: %[[X:.*]] = constant 0
# CHECK: %[[Y:.*]] = constant 1 # CHECK: %[[Y:.*]] = constant 1
# CHECK: %[[Z:.*]] = constant 2 # CHECK: %[[Z:.*]] = constant 2
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]] # CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]] # CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
# CHECK: scf.yield %[[XCAST]] # CHECK: scf.yield %[[XCAST]]
# CHECK: } else { # CHECK: } else {
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]] # CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]] # CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
# CHECK: scf.yield %[[YCAST]] # CHECK: scf.yield %[[YCAST]]
@ -68,7 +68,7 @@ def logical_not():
x = 1 x = 1
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true # CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false # CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]] # CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType # CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
return not x return not x
@ -78,7 +78,7 @@ def logical_not():
def conditional(): def conditional():
# CHECK: %[[X:.*]] = constant 1 # CHECK: %[[X:.*]] = constant 1
x = 1 x = 1
# CHECK: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]] # CHECK: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) { # CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
# CHECK: %[[TWO:.*]] = constant 2 : i64 # CHECK: %[[TWO:.*]] = constant 2 : i64
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]] # CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]