diff --git a/include/npcomp/Dialect/Basicpy/BasicpyOps.td b/include/npcomp/Dialect/Basicpy/BasicpyOps.td index 38e90158a..f203413fd 100644 --- a/include/npcomp/Dialect/Basicpy/BasicpyOps.td +++ b/include/npcomp/Dialect/Basicpy/BasicpyOps.td @@ -240,6 +240,19 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [ let assemblyFormat = "attr-dict `:` type($result)"; } +def Basicpy_ToBooleanOp : Basicpy_Op<"to_boolean", [NoSideEffect]> { + let summary = "Evaluates an input to an i1 boolean value"; + let description = [{ + Applies the rules for interpreting a type as a boolean, returning an i1 + indicating the truthiness of the operand. Since the output of this op + is intended to drive lower-level control flow, the i1 type is used (not + the user level BoolType). + }]; + let arguments = (ins AnyType:$operand); + let results = (outs I1:$result); + let assemblyFormat = "$operand attr-dict `:` type($operand)"; +} + def Basicpy_UnknownCastOp : Basicpy_Op<"unknown_cast", [NoSideEffect]> { let summary = "Casts to and from the UnknownType"; let arguments = (ins AnyType:$input); diff --git a/pytest/Compiler/booleans.py b/pytest/Compiler/booleans.py new file mode 100644 index 000000000..49a1bbb8b --- /dev/null +++ b/pytest/Compiler/booleans.py @@ -0,0 +1,77 @@ +# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail + +from npcomp.compiler.frontend import * + + +def import_global(f): + fe = ImportFrontend() + fe.import_global_function(f) + print("// -----") + print(fe.ir_module.to_asm()) + return f + + +# CHECK-LABEL: func @logical_and +@import_global +def logical_and(): + # CHECK: %[[X:.*]] = constant 1 + # CHECK: %[[Y:.*]] = constant 0 + # CHECK: %[[Z:.*]] = constant 2 + x = 1 + y = 0 + z = 2 + # CHECK: %[[XBOOL:.*]] = basicpy.to_boolean %[[X]] + # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { + # CHECK: %[[YBOOL:.*]] = basicpy.to_boolean %[[Y]] + # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { + # CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]] + # CHECK: scf.yield %[[ZCAST]] + # CHECK: } else { + # CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]] + # CHECK: scf.yield %[[YCAST]] + # CHECK: } + # CHECK: %[[IF1CAST:.*]] = basicpy.unknown_cast %[[IF1]] + # CHECK: scf.yield %[[IF1CAST]] + # CHECK: } else { + # CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]] + # CHECK: scf.yield %[[XCAST]] + # CHECK: } + return x and y and z + +# CHECK-LABEL: func @logical_or +@import_global +def logical_or(): + # CHECK: %[[X:.*]] = constant 0 + # CHECK: %[[Y:.*]] = constant 1 + # CHECK: %[[Z:.*]] = constant 2 + # CHECK: %[[XBOOL:.*]] = basicpy.to_boolean %[[X]] + # CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) { + # CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]] + # CHECK: scf.yield %[[XCAST]] + # CHECK: } else { + # CHECK: %[[YBOOL:.*]] = basicpy.to_boolean %[[Y]] + # CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) { + # CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]] + # CHECK: scf.yield %[[YCAST]] + # CHECK: } else { + # CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]] + # CHECK: scf.yield %[[ZCAST]] + # CHECK: } + # CHECK: %[[IF1CAST:.*]] = basicpy.unknown_cast %[[IF1]] + # CHECK: scf.yield %[[IF1CAST]] + # CHECK: } + x = 0 + y = 1 + z = 2 + return x or y or z + +# CHECK-LABEL: func @logical_not +@import_global +def logical_not(): + # CHECK: %[[X:.*]] = constant 1 + x = 1 + # CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant 1 + # CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant 0 + # CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]] + # CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType + return not x diff --git a/python/npcomp/compiler/frontend.py b/python/npcomp/compiler/frontend.py index 1ea27a148..2d74f6e51 100644 --- a/python/npcomp/compiler/frontend.py +++ b/python/npcomp/compiler/frontend.py @@ -237,6 +237,41 @@ class ExpressionImporter(BaseNodeVisitor): ir_h.basicpy_UnknownType, left, right, ast_node.op.__class__.__name__).result + def visit_BoolOp(self, ast_node): + ir_h = self.fctx.ir_h + if isinstance(ast_node.op, ast.And): + return_first_true = False + elif isinstance(ast_node.op, ast.Or): + return_first_true = True + else: + self.fctx.abort("unknown bool op %r" % (ast.dump(ast_node.op))) + + def emit_next(next_nodes): + next_node = next_nodes[0] + next_nodes = next_nodes[1:] + next_value = self.sub_evaluate(next_node) + if not next_nodes: + return next_value + condition_value = ir_h.basicpy_to_boolean_op(next_value).result + if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType], + condition_value, True) + orig_ip = ir_h.builder.insertion_point + # Short-circuit return case. + ir_h.builder.insertion_point = then_ip if return_first_true else else_ip + next_value_casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType, + next_value).result + ir_h.scf_yield_op([next_value_casted]) + # Nested evaluate next case. + ir_h.builder.insertion_point = else_ip if return_first_true else then_ip + nested_value = emit_next(next_nodes) + nested_value_casted = next_value_casted = ir_h.basicpy_unknown_cast_op( + ir_h.basicpy_UnknownType, nested_value).result + ir_h.scf_yield_op([nested_value_casted]) + ir_h.builder.insertion_point = orig_ip + return if_op.result + + self.value = emit_next(ast_node.values) + def visit_Compare(self, ast_node): # Short-circuit comparison (degenerates to binary comparison when just # two operands). @@ -284,6 +319,20 @@ class ExpressionImporter(BaseNodeVisitor): self.fctx.abort("Local variable '%s' has not been assigned" % ast_node.id) self.value = value + def visit_UnaryOp(self, ast_node): + ir_h = self.fctx.ir_h + op = ast_node.op + operand_value = self.sub_evaluate(ast_node.operand) + if isinstance(op, ast.Not): + # Special handling for logical-not. + condition_value = ir_h.basicpy_to_boolean_op(operand_value).result + true_value = ir_h.basicpy_bool_constant_op(True).result + false_value = ir_h.basicpy_bool_constant_op(False).result + self.value = ir_h.select_op(condition_value, false_value, + true_value).result + else: + self.fctx.abort("Unknown unary op %r", (ast.dump(op))) + if sys.version_info < (3, 8, 0): # <3.8 breaks these out into separate AST classes. def visit_Num(self, ast_node): diff --git a/python/npcomp/dialect/Basicpy.py b/python/npcomp/dialect/Basicpy.py index 996c05a0f..6f65a149c 100644 --- a/python/npcomp/dialect/Basicpy.py +++ b/python/npcomp/dialect/Basicpy.py @@ -90,6 +90,9 @@ class DialectHelper(_BaseDialectHelper): attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))}) return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs) + def basicpy_to_boolean_op(self, value): + return self.op("basicpy.to_boolean", [self.i1_type], [value]) + def basicpy_unknown_cast_op(self, result_type, operand): return self.op("basicpy.unknown_cast", [result_type], [operand]) diff --git a/python_native/CoreDialects.cpp b/python_native/CoreDialects.cpp index 9d4d183fb..b6180f0c6 100644 --- a/python_native/CoreDialects.cpp +++ b/python_native/CoreDialects.cpp @@ -55,7 +55,7 @@ public: op.getThenBodyBuilder().saveInsertionPoint()); } }, - py::arg("cond"), py::arg("result_types"), + py::arg("result_types"), py::arg("cond"), py::arg("with_else_region") = false); } }; diff --git a/python_native/MlirIr.cpp b/python_native/MlirIr.cpp index aa8033bd2..c6609b484 100644 --- a/python_native/MlirIr.cpp +++ b/python_native/MlirIr.cpp @@ -248,6 +248,15 @@ void PyDialectHelper::bind(py::module m) { R"(Creates a new `func` op, optionally creating an entry block. If an entry block is created, the builder will be positioned to its start.)") + .def("select_op", + [](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue, + PyValue falseValue) -> PyOperationRef { + OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); + Location loc = self.pyOpBuilder.getCurrentLoc(); + return PyOperationRef(opBuilder.create( + loc, conditionValue, trueValue, falseValue)); + }, + py::arg("condition"), py::arg("true_value"), py::arg("false_value")) .def("return_op", [](PyDialectHelper &self, std::vector pyOperands) { OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);