From 3937dd14cb41a3bfb6bbddbd056502d902ee2bac Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 23 Nov 2020 19:20:26 -0800 Subject: [PATCH] Add basicpy.numeric_constant op. * Going through TODOs on the PyTorch side, this is a big cause of them (not being able to have constants for signed/unsigned). * Added complex while in here since we're at the phase where it is better to just have things complete than partially done. --- .../pytorch/csrc/builder/func_builder.cpp | 42 ++++-- frontends/pytorch/csrc/builder/func_builder.h | 3 +- .../Dialect/Basicpy/IR/BasicpyDialect.td | 8 +- .../npcomp/Dialect/Basicpy/IR/BasicpyOps.h | 1 + .../npcomp/Dialect/Basicpy/IR/BasicpyOps.td | 58 +++++++- .../BasicpyToStd/PrimitiveOpsConversion.cpp | 8 +- lib/Dialect/Basicpy/IR/BasicpyDialect.cpp | 31 ++++ lib/Dialect/Basicpy/IR/BasicpyOps.cpp | 139 ++++++++++++++++-- .../Basicpy/Transforms/TypeInference.cpp | 2 +- lib/Typing/Transforms/CPATypeInference.cpp | 2 +- python/npcomp/compiler/numpy/importer.py | 7 +- python/npcomp/dialect/Basicpy.py | 4 +- test/Dialect/Basicpy/canonicalize.mlir | 57 +++++++ test/Dialect/Basicpy/ops-invalid.mlir | 49 ++++++ test/Dialect/Basicpy/ops.mlir | 14 +- test/Python/Compiler/Numpy/booleans.py | 12 +- 16 files changed, 376 insertions(+), 61 deletions(-) create mode 100644 test/Dialect/Basicpy/ops-invalid.mlir diff --git a/frontends/pytorch/csrc/builder/func_builder.cpp b/frontends/pytorch/csrc/builder/func_builder.cpp index a7967e31d..888a67721 100644 --- a/frontends/pytorch/csrc/builder/func_builder.cpp +++ b/frontends/pytorch/csrc/builder/func_builder.cpp @@ -17,7 +17,8 @@ using namespace torch_mlir; static MlirOperation createStandardConstant(MlirLocation loc, MlirType type, MlirAttribute value) { OperationStateHolder s("std.constant", loc); - MlirNamedAttribute valueAttr = mlirNamedAttributeGet(toMlirStringRef("value"), value); + MlirNamedAttribute valueAttr = + mlirNamedAttributeGet(toMlirStringRef("value"), value); mlirOperationStateAddResults(s, 1, &type); mlirOperationStateAddAttributes(s, 1, &valueAttr); return s.createOperation(); @@ -44,12 +45,15 @@ void KernelCallBuilder::addSchemaAttrs() { // sigIsVarret // sigIsMutable llvm::SmallVector attrs; - attrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("sigIsMutable"), mlirBoolAttrGet(context, schema.is_mutable()))); - attrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("sigIsVararg"), mlirBoolAttrGet(context, schema.is_vararg()))); - attrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("sigIsVarret"), mlirBoolAttrGet(context, schema.is_varret()))); + attrs.push_back( + mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"), + mlirBoolAttrGet(context, schema.is_mutable()))); + attrs.push_back( + mlirNamedAttributeGet(toMlirStringRef("sigIsVararg"), + mlirBoolAttrGet(context, schema.is_vararg()))); + attrs.push_back( + mlirNamedAttributeGet(toMlirStringRef("sigIsVarret"), + mlirBoolAttrGet(context, schema.is_varret()))); // Arg types. llvm::SmallVector args; @@ -58,7 +62,8 @@ void KernelCallBuilder::addSchemaAttrs() { args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data())); } attrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("sigArgTypes"), mlirArrayAttrGet(context, args.size(), args.data()))); + toMlirStringRef("sigArgTypes"), + mlirArrayAttrGet(context, args.size(), args.data()))); // Return types. llvm::SmallVector returns; @@ -203,14 +208,17 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter, // TODO: Create a dedicated API upstream for creating/manipulating func ops. // (this is fragile and reveals details that are not guaranteed). llvm::SmallVector funcAttrs; + funcAttrs.push_back( + mlirNamedAttributeGet(toMlirStringRef("type"), + mlirTypeAttrGet(mlirFunctionTypeGet( + context, inputTypes.size(), inputTypes.data(), + /*numResults=*/0, /*results=*/nullptr)))); funcAttrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("type"), mlirTypeAttrGet(mlirFunctionTypeGet( - context, inputTypes.size(), inputTypes.data(), - /*numResults=*/0, /*results=*/nullptr)))); - funcAttrs.push_back(mlirNamedAttributeGet( - toMlirStringRef("sym_name"), mlirStringAttrGet(context, name.size(), name.data()))); + toMlirStringRef("sym_name"), + mlirStringAttrGet(context, name.size(), name.data()))); - MlirOperationState state = mlirOperationStateGet(toMlirStringRef("func"), location); + MlirOperationState state = + mlirOperationStateGet(toMlirStringRef("func"), location); mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data()); { // Don't access these once ownership transferred. @@ -234,7 +242,8 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter, void FuncBuilder::rewriteFuncReturnTypes( llvm::SmallVectorImpl &resultTypes) { // Get inputs from current function type. - MlirAttribute funcTypeAttr = mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type")); + MlirAttribute funcTypeAttr = + mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type")); assert(!mlirAttributeIsNull(funcTypeAttr) && "function missing 'type' attribute"); assert(mlirAttributeIsAType(funcTypeAttr) && @@ -250,7 +259,8 @@ void FuncBuilder::rewriteFuncReturnTypes( mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), resultTypes.size(), resultTypes.data()); MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType); - mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"), newFuncTypeAttr); + mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"), + newFuncTypeAttr); (void)newFuncTypeAttr; } diff --git a/frontends/pytorch/csrc/builder/func_builder.h b/frontends/pytorch/csrc/builder/func_builder.h index fc38591d0..e8c7a42c2 100644 --- a/frontends/pytorch/csrc/builder/func_builder.h +++ b/frontends/pytorch/csrc/builder/func_builder.h @@ -24,8 +24,7 @@ namespace torch_mlir { class OperationStateHolder { public: OperationStateHolder(const char *name, MlirLocation loc) - : state( - mlirOperationStateGet(toMlirStringRef(name), loc)) {} + : state(mlirOperationStateGet(toMlirStringRef(name), loc)) {} OperationStateHolder(const OperationStateHolder &) = delete; OperationStateHolder(OperationStateHolder &&other) = delete; ~OperationStateHolder() { diff --git a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td index 23c7a9acf..1b1df7bdc 100644 --- a/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td +++ b/include/npcomp/Dialect/Basicpy/IR/BasicpyDialect.td @@ -22,6 +22,7 @@ def Basicpy_Dialect : Dialect { Core types and ops }]; let cppNamespace = "::mlir::NPCOMP::Basicpy"; + let hasConstantMaterializer = 1; } //===----------------------------------------------------------------------===// @@ -30,8 +31,9 @@ def Basicpy_Dialect : Dialect { class Basicpy_Op traits = []> : Op { - let parser = [{ return parse$cppClass(parser, &result); }]; - let printer = [{ return print$cppClass(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, &result); }]; + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; } //===----------------------------------------------------------------------===// @@ -138,7 +140,7 @@ def Basicpy_DictType : DialectType]> { + let summary = "A constant from the Python3 numeric type hierarchy"; + let description = [{ + Basicpy re-uses core MLIR types to represent the Python3 numeric type + hierarchy with the following mappings: + + * Python3 `int` : In python, this type is signed, arbitrary precision but + in typical realizations, it maps to an MLIR `IntegerType` of a fixed + bit-width (typically si64 if no further information is known). In the + future, there may be a real `Basicpy::IntType` that retains the true + arbitrary precision nature, but this is deemed an enhancement that + does not obviate the need to infer physical, sized types for many + real-world cases. As such, the Basicpy numeric type hierarchy will + always include physical `IntegerType`, if only to enable progressive + lowering and interop with cases where the precise type is known. + * Python3 `float` : This is allowed to map to any legal floating point + type on the physical machine and is usually represented as a double (f64). + In MLIR, any `FloatType` is allowed, which facilitates progressive + lowering and interop with cases where a more precise type is known. + * Python3 `complex` : Maps to an MLIR `ComplexType` with a `FloatType` + elementType (note: in Python, complex numbers are always defined with + floating point components). + * `bool` : See `bool_constant` for a constant (i1) -> !basicpy.BoolType + constant. This constant op is not used for representing such bool + values, even though from the Python perspective, bool is part of the + numeric hierarchy (the distinction is really only necessary during + promotion). + + ### Integer Signedness + + All `int` values in Python are signed. However, there exist special cases + where libraries (i.e. struct packing and numpy arrays) interoperate with + unsigned values. As such, when mapping to MLIR, Python integer types + are represented as either signed or unsigned `IntegerType` types and can + be lowered to signless integers as appropriate (typically during realization + of arithmetic expressions where the choice is meaningful). Since it is not + known at the outset when in lowering this information is safe to discard + this `numeric_constant` op accepts any signedness. + }]; + + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType); + let hasFolder = 1; +} + def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [ - ConstantLike, NoSideEffect]> { + ConstantLike, NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "A boolean constant"; let description = [{ A constant of type !basicpy.BoolType that can take either an i1 value @@ -173,7 +220,7 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> { } def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [ - ConstantLike, NoSideEffect]> { + ConstantLike, NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Constant bytes value"; let description = [{ A bytes value of BytesType. The value is represented by a StringAttr. @@ -204,7 +251,7 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [ } def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [ - ConstantLike, NoSideEffect]> { + ConstantLike, NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Constant string value"; let description = [{ A string value of StrType. The value is represented by a StringAttr @@ -224,7 +271,7 @@ def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [ // Casting and coercion operations //===----------------------------------------------------------------------===// -def Basicpy_AsPredicateValueOp : Basicpy_Op<"as_predicate_value", +def Basicpy_AsI1Op : Basicpy_Op<"as_i1", [NoSideEffect]> { let summary = "Evaluates an input to an i1 predicate value"; let description = [{ @@ -355,7 +402,6 @@ def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> { StrArrayAttr:$arg_names); let results = (outs AnyType:$result); let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)"; - let verifier = [{ return verifyBasicpyOp(*this); }]; let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins)>, @@ -427,8 +473,6 @@ def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [ let arguments = (ins); let regions = (region SizedRegion<1>:$body); - let verifier = [{ return verifyBasicpyOp(*this); }]; - let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins)>, diff --git a/lib/Conversion/BasicpyToStd/PrimitiveOpsConversion.cpp b/lib/Conversion/BasicpyToStd/PrimitiveOpsConversion.cpp index 5f442bf05..691791eb7 100644 --- a/lib/Conversion/BasicpyToStd/PrimitiveOpsConversion.cpp +++ b/lib/Conversion/BasicpyToStd/PrimitiveOpsConversion.cpp @@ -215,11 +215,11 @@ public: } }; -// Converts the as_predicate_value op for numeric types. -class NumericToPredicateValue : public OpRewritePattern { +// Converts the as_i1 op for numeric types. +class NumericToI1 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Basicpy::AsPredicateValueOp op, + LogicalResult matchAndRewrite(Basicpy::AsI1Op op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto operandType = op.operand().getType(); @@ -245,5 +245,5 @@ void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert(context); patterns.insert(context); - patterns.insert(context); + patterns.insert(context); } diff --git a/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp b/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp index 221381374..4e5845ad0 100644 --- a/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp +++ b/lib/Dialect/Basicpy/IR/BasicpyDialect.cpp @@ -27,6 +27,37 @@ void BasicpyDialect::initialize() { allowUnknownOperations(); } +Operation *BasicpyDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + // NumericConstantOp. + // Supports IntegerType (any signedness), FloatType and ComplexType. + if (type.isa() || type.isa() || + type.isa()) + return builder.create(loc, type, value); + + // Bool (i1 -> !basicpy.BoolType). + if (type.isa()) { + auto i1Value = value.dyn_cast(); + if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1) + return builder.create(loc, type, i1Value); + } + + // Bytes. + if (type.isa()) { + if (auto strValue = value.dyn_cast()) + return builder.create(loc, type, strValue); + } + + // Str. + if (type.isa()) { + if (auto strValue = value.dyn_cast()) + return builder.create(loc, type, strValue); + } + + return nullptr; +} + Type BasicpyDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) diff --git a/lib/Dialect/Basicpy/IR/BasicpyOps.cpp b/lib/Dialect/Basicpy/IR/BasicpyOps.cpp index bf4c25579..7d741f37d 100644 --- a/lib/Dialect/Basicpy/IR/BasicpyOps.cpp +++ b/lib/Dialect/Basicpy/IR/BasicpyOps.cpp @@ -13,12 +13,13 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" - #include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc" -namespace mlir { -namespace NPCOMP { -namespace Basicpy { +using namespace mlir; +using namespace mlir::NPCOMP::Basicpy; + +// Fallback verifier for ops that don't have a dedicated one. +template static LogicalResult verify(T op) { return success(); } //===----------------------------------------------------------------------===// // BoolConstantOp @@ -28,6 +29,11 @@ OpFoldResult BoolConstantOp::fold(ArrayRef operands) { return valueAttr(); } +void BoolConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "bool"); +} + //===----------------------------------------------------------------------===// // BytesConstantOp //===----------------------------------------------------------------------===// @@ -36,6 +42,110 @@ OpFoldResult BytesConstantOp::fold(ArrayRef operands) { return valueAttr(); } +void BytesConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "bytes"); +} + +//===----------------------------------------------------------------------===// +// NumericConstantOp +//===----------------------------------------------------------------------===// + +static ParseResult parseNumericConstantOp(OpAsmParser &parser, + OperationState *result) { + Attribute valueAttr; + if (parser.parseOptionalAttrDict(result->attributes) || + parser.parseAttribute(valueAttr, "value", result->attributes)) + return failure(); + + // If not an Integer or Float attr (which carry the type in the attr), + // expect a trailing type. + Type type; + if (valueAttr.isa() || valueAttr.isa()) + type = valueAttr.getType(); + else if (parser.parseColonType(type)) + return failure(); + return parser.addTypeToList(type, result->types); +} + +static void print(OpAsmPrinter &p, NumericConstantOp op) { + p << "basicpy.numeric_constant "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); + + if (op.getAttrs().size() > 1) + p << ' '; + p << op.value(); + + // If not an Integer or Float attr, expect a trailing type. + if (!op.value().isa() && !op.value().isa()) + p << " : " << op.getType(); +} + +static LogicalResult verify(NumericConstantOp &op) { + auto value = op.value(); + if (!value) + return op.emitOpError("requires a 'value' attribute"); + auto type = op.getType(); + + if (type.isa()) { + if (!value.isa()) + return op.emitOpError("requires 'value' to be a floating point constant"); + return success(); + } + + if (auto intType = type.dyn_cast()) { + if (!value.isa()) + return op.emitOpError("requires 'value' to be an integer constant"); + if (intType.getWidth() == 1) + return op.emitOpError("cannot have an i1 type"); + return success(); + } + + if (type.isa()) { + if (auto complexComps = value.dyn_cast()) { + if (complexComps.size() == 2) { + auto realValue = complexComps[0].dyn_cast(); + auto imagValue = complexComps[1].dyn_cast(); + if (realValue && imagValue && + realValue.getType() == imagValue.getType()) + return success(); + } + } + return op.emitOpError("requires 'value' to be a two element array of " + "floating point complex number components"); + } + + return op.emitOpError("unsupported basicpy.numeric_constant type"); +} + +OpFoldResult NumericConstantOp::fold(ArrayRef operands) { + assert(operands.empty() && "numeric_constant has no operands"); + return value(); +} + +void NumericConstantOp::getAsmResultNames( + function_ref setNameFn) { + Type type = getType(); + if (auto intCst = value().dyn_cast()) { + IntegerType intTy = type.dyn_cast(); + APInt intValue = intCst.getValue(); + + // Otherwise, build a complex name with the value and type. + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "num"; + if (intTy.isSigned()) + specialName << intValue.getSExtValue(); + else + specialName << intValue.getZExtValue(); + if (intTy) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + } else { + setNameFn(getResult(), "num"); + } +} + //===----------------------------------------------------------------------===// // ExecOp //===----------------------------------------------------------------------===// @@ -54,7 +164,7 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) { return success(); } -static void printExecOp(OpAsmPrinter &p, ExecOp op) { +static void print(OpAsmPrinter &p, ExecOp op) { p << op.getOperationName(); p.printOptionalAttrDictWithKeyword(op.getAttrs()); p.printRegion(op.body()); @@ -64,7 +174,7 @@ static void printExecOp(OpAsmPrinter &p, ExecOp op) { // FuncTemplateCallOp //===----------------------------------------------------------------------===// -static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) { +static LogicalResult verify(FuncTemplateCallOp op) { auto argNames = op.arg_names(); if (argNames.size() > op.args().size()) { return op.emitOpError() << "expected <= kw arg names vs args"; @@ -108,7 +218,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser, return success(); } -static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) { +static void print(OpAsmPrinter &p, FuncTemplateOp op) { p << op.getOperationName() << " "; p.printSymbolName(op.getName()); p.printOptionalAttrDictWithKeyword(op.getAttrs(), @@ -116,7 +226,7 @@ static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) { p.printRegion(op.body()); } -static LogicalResult verifyBasicpyOp(FuncTemplateOp op) { +static LogicalResult verify(FuncTemplateOp op) { Block *body = op.getBody(); for (auto &childOp : body->getOperations()) { if (!llvm::isa(childOp) && @@ -151,7 +261,7 @@ static ParseResult parseSlotObjectMakeOp(OpAsmParser &parser, parser.getNameLoc(), result->operands); } -static void printSlotObjectMakeOp(OpAsmPrinter &p, SlotObjectMakeOp op) { +static void print(OpAsmPrinter &p, SlotObjectMakeOp op) { // If the argument types do not match the result type slots, then // print the generic form. auto canCustomPrint = ([&]() -> bool { @@ -218,7 +328,7 @@ static ParseResult parseSlotObjectGetOp(OpAsmParser &parser, return success(); } -static void printSlotObjectGetOp(OpAsmPrinter &p, SlotObjectGetOp op) { +static void print(OpAsmPrinter &p, SlotObjectGetOp op) { // If the argument types do not match the result type slots, then // print the generic form. auto canCustomPrint = ([&]() -> bool { @@ -262,6 +372,11 @@ OpFoldResult StrConstantOp::fold(ArrayRef operands) { return valueAttr(); } +void StrConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "str"); +} + //===----------------------------------------------------------------------===// // UnknownCastOp //===----------------------------------------------------------------------===// @@ -287,9 +402,5 @@ void UnknownCastOp::getCanonicalizationPatterns( patterns.insert(context); } -} // namespace Basicpy -} // namespace NPCOMP -} // namespace mlir - #define GET_OP_CLASSES #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc" diff --git a/lib/Dialect/Basicpy/Transforms/TypeInference.cpp b/lib/Dialect/Basicpy/Transforms/TypeInference.cpp index 81d87eb8a..43b0ef099 100644 --- a/lib/Dialect/Basicpy/Transforms/TypeInference.cpp +++ b/lib/Dialect/Basicpy/Transforms/TypeInference.cpp @@ -342,7 +342,7 @@ public: op); return WalkResult::advance(); } - if (auto op = dyn_cast(childOp)) { + if (auto op = dyn_cast(childOp)) { // Note that the result is always i1 and not subject to type // inference. equations.getTypeNode(op.operand()); diff --git a/lib/Typing/Transforms/CPATypeInference.cpp b/lib/Typing/Transforms/CPATypeInference.cpp index 2487c6eec..a848399fd 100644 --- a/lib/Typing/Transforms/CPATypeInference.cpp +++ b/lib/Typing/Transforms/CPATypeInference.cpp @@ -140,7 +140,7 @@ public: // addSubtypeConstraint(op.false_value(), op.true_value(), op); return WalkResult::advance(); } - if (auto op = dyn_cast(childOp)) { + if (auto op = dyn_cast(childOp)) { // Note that the result is always i1 and not subject to type // inference. resolveValueType(op.operand()); diff --git a/python/npcomp/compiler/numpy/importer.py b/python/npcomp/compiler/numpy/importer.py index 569f09c6a..b060a3feb 100644 --- a/python/npcomp/compiler/numpy/importer.py +++ b/python/npcomp/compiler/numpy/importer.py @@ -255,7 +255,7 @@ class ExpressionImporter(BaseNodeVisitor): next_value = self.sub_evaluate(next_node) if not next_nodes: return next_value - condition_value = ir_h.basicpy_as_predicate_value_op(next_value).result + condition_value = ir_h.basicpy_as_i1_op(next_value).result if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType], condition_value, True) orig_ip = ir_h.builder.insertion_point @@ -347,8 +347,7 @@ class ExpressionImporter(BaseNodeVisitor): def visit_IfExp(self, ast_node): ir_h = self.fctx.ir_h - test_result = ir_h.basicpy_as_predicate_value_op(self.sub_evaluate( - ast_node.test)).result + test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType], test_result, True) @@ -386,7 +385,7 @@ class ExpressionImporter(BaseNodeVisitor): operand_value = self.sub_evaluate(ast_node.operand) if isinstance(op, ast.Not): # Special handling for logical-not. - condition_value = ir_h.basicpy_as_predicate_value_op(operand_value).result + condition_value = ir_h.basicpy_as_i1_op(operand_value).result true_value = ir_h.basicpy_bool_constant_op(True).result false_value = ir_h.basicpy_bool_constant_op(False).result self.value = ir_h.select_op(condition_value, false_value, diff --git a/python/npcomp/dialect/Basicpy.py b/python/npcomp/dialect/Basicpy.py index 6ce11751f..7fa717180 100644 --- a/python/npcomp/dialect/Basicpy.py +++ b/python/npcomp/dialect/Basicpy.py @@ -90,8 +90,8 @@ class DialectHelper(_BaseDialectHelper): attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))}) return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs) - def basicpy_as_predicate_value_op(self, value): - return self.op("basicpy.as_predicate_value", [self.i1_type], [value]) + def basicpy_as_i1_op(self, value): + return self.op("basicpy.as_i1", [self.i1_type], [value]) def basicpy_unknown_cast_op(self, result_type, operand): return self.op("basicpy.unknown_cast", [result_type], [operand]) diff --git a/test/Dialect/Basicpy/canonicalize.mlir b/test/Dialect/Basicpy/canonicalize.mlir index 3be2ed341..80943db7c 100644 --- a/test/Dialect/Basicpy/canonicalize.mlir +++ b/test/Dialect/Basicpy/canonicalize.mlir @@ -7,9 +7,66 @@ func @unknown_cast_elide(%arg0 : i32) -> i32 { return %0 : i32 } +// ----- // CHECK-LABEL: func @unknown_cast_preserve func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType { // CHECK: basicpy.unknown_cast %0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType return %0 : !basicpy.UnknownType } + +// ----- +// CHECK-LABEL: @numeric_constant_si32 +func @numeric_constant_si32() -> si32 { + // CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32 + %0 = basicpy.numeric_constant -1 : si32 + return %0 : si32 +} + +// ----- +// CHECK-LABEL: @numeric_constant_ui32 +func @numeric_constant_ui32() -> ui32 { + // CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32 + %0 = basicpy.numeric_constant 1 : ui32 + return %0 : ui32 +} + +// ----- +// CHECK-LABEL: @numeric_constant_f32 +func @numeric_constant_f32() -> f32 { + // CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32 + %0 = basicpy.numeric_constant 2.0 : f32 + return %0 : f32 +} + +// ----- +// CHECK-LABEL: @numeric_constant_complex_f32 +func @numeric_constant_complex_f32() -> complex { + // CHECK: %num = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex + %0 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex + return %0 : complex +} + +// ----- +// CHECK-LABEL: @bool_constant +func @bool_constant() -> !basicpy.BoolType { + // CHECK: %bool = basicpy.bool_constant true + %0 = basicpy.bool_constant true + return %0 : !basicpy.BoolType +} + +// ----- +// CHECK-LABEL: @bytes_constant +func @bytes_constant() -> !basicpy.BytesType { + // CHECK: %bytes = basicpy.bytes_constant "foobar" + %0 = basicpy.bytes_constant "foobar" + return %0 : !basicpy.BytesType +} + +// ----- +// CHECK-LABEL: @str_constant +func @str_constant() -> !basicpy.StrType { + // CHECK: %str = basicpy.str_constant "foobar" + %0 = basicpy.str_constant "foobar" + return %0 : !basicpy.StrType +} diff --git a/test/Dialect/Basicpy/ops-invalid.mlir b/test/Dialect/Basicpy/ops-invalid.mlir new file mode 100644 index 000000000..416c364a1 --- /dev/null +++ b/test/Dialect/Basicpy/ops-invalid.mlir @@ -0,0 +1,49 @@ +// RUN: npcomp-opt -split-input-file -verify-diagnostics %s + +func @numeric_constant_string_attr() { + // expected-error @+1 {{op requires 'value' to be an integer constant}} + %0 = "basicpy.numeric_constant"() {value="somestring" : i32} : () -> (i32) + return +} + +// ----- +func @numeric_constant_bool() { + // expected-error @+1 {{cannot have an i1 type}} + %0 = "basicpy.numeric_constant"() {value = true} : () -> (i1) + return +} + +// ----- +func @numeric_constant_mismatch_int() { + // expected-error @+1 {{op requires 'value' to be a floating point constant}} + %0 = "basicpy.numeric_constant"() {value = 1 : i32} : () -> (f64) + return +} + +// ----- +func @numeric_constant_mismatch_float() { + // expected-error @+1 {{op requires 'value' to be an integer constant}} + %0 = "basicpy.numeric_constant"() {value = 1.0 : f32} : () -> (i32) + return +} + +// ----- +func @numeric_constant_complex_wrong_arity() { + // expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}} + %3 = basicpy.numeric_constant [2.0 : f32] : complex + return +} + +// ----- +func @numeric_constant_complex_mismatch_type_real() { + // expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}} + %3 = basicpy.numeric_constant [2.0 : f64, 3.0 : f32] : complex + return +} + +// ----- +func @numeric_constant_complex_mismatch_type_imag() { + // expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}} + %3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f16] : complex + return +} diff --git a/test/Dialect/Basicpy/ops.mlir b/test/Dialect/Basicpy/ops.mlir index 1856a302e..bb5fa63c0 100644 --- a/test/Dialect/Basicpy/ops.mlir +++ b/test/Dialect/Basicpy/ops.mlir @@ -24,4 +24,16 @@ func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType { return %0 : !basicpy.TupleType } - +// ----- +// CHECK-LABEL: @numeric_constant +func @numeric_constant() { + // CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32 + %0 = basicpy.numeric_constant -1 : si32 + // CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32 + %1 = basicpy.numeric_constant 1 : ui32 + // CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32 + %2 = basicpy.numeric_constant 2.0 : f32 + // CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex + %3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex + return +} diff --git a/test/Python/Compiler/Numpy/booleans.py b/test/Python/Compiler/Numpy/booleans.py index da3882ad2..7c76eadff 100644 --- a/test/Python/Compiler/Numpy/booleans.py +++ b/test/Python/Compiler/Numpy/booleans.py @@ -14,9 +14,9 @@ def logical_and(): x = 1 y = 0 z = 2 - # CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]] + # CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]] # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { - # CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]] + # CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]] # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]] # CHECK: scf.yield %[[ZCAST]] @@ -39,12 +39,12 @@ def logical_or(): # CHECK: %[[X:.*]] = constant 0 # CHECK: %[[Y:.*]] = constant 1 # CHECK: %[[Z:.*]] = constant 2 - # CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]] + # CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]] # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]] # CHECK: scf.yield %[[XCAST]] # CHECK: } else { - # CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]] + # CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]] # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { # CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]] # CHECK: scf.yield %[[YCAST]] @@ -68,7 +68,7 @@ def logical_not(): x = 1 # CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true # CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false - # CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]] + # CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]] # CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType return not x @@ -78,7 +78,7 @@ def logical_not(): def conditional(): # CHECK: %[[X:.*]] = constant 1 x = 1 - # CHECK: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]] + # CHECK: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]] # CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) { # CHECK: %[[TWO:.*]] = constant 2 : i64 # CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]