Add support for short-circuit comparisons with scf.if.

pull/1/head
Stella Laurenzo 2020-06-08 17:52:07 -07:00
parent a32219c3bb
commit 1ef3614682
7 changed files with 163 additions and 17 deletions

View File

@ -13,6 +13,12 @@ include "BasicpyDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
// Predicates
//===----------------------------------------------------------------------===//
def BoolOrI1Type : AnyTypeOf<[Basicpy_BoolType, I1], "Python bool or i1">;
//===----------------------------------------------------------------------===//
// Binary operation enum
// The name matches the operation name in the python AST ("Add", "Mult", etc).
@ -101,6 +107,18 @@ def Basicpy_BinaryExprOp : Basicpy_Op<"binary_expr", []> {
let assemblyFormat = "$operation operands attr-dict `:` functional-type(operands, results)";
}
def Basicpy_BoolCastOp : Basicpy_Op<"bool_cast", [NoSideEffect]> {
let summary = "Casts between BoolType and i1";
let description = [{
When interfacing with lower level dialect or progressively lowering
the Python BoolType away, it is often necessary to cast between it and
i1, which is used to represent bool-ness at lower levels.
}];
let arguments = (ins BoolOrI1Type:$operand);
let results = (outs BoolOrI1Type:$result);
let assemblyFormat = "$operand attr-dict `:` type(operands) `->` type(results)";
}
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
ConstantLike, NoSideEffect]> {
let summary = "A boolean constant";

View File

@ -93,10 +93,44 @@ def binary_not_in_():
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64
return x not in y
# @import_global
# def short_circuit():
# x = 1
# y = 2
# z = 3
# return x < y < z
@import_global
def short_circuit():
# CHECK: %[[X:.*]] = constant 1 : i64
# CHECK: %[[Y:.*]] = constant 2 : i64
# CHECK: %[[Z:.*]] = constant 3 : i64
# CHECK: %[[OMEGA:.*]] = constant 5 : i64
x = 1
y = 2
z = 3
omega = 5
# CHECK: %[[FALSE:.*]] = basicpy.bool_constant 0
# CHECK: %[[CMP0:.*]] = basicpy.binary_compare %[[X]] "Lt" %[[Y]]
# CHECK: %[[CMP0_CAST:.*]] = basicpy.bool_cast %[[CMP0]] : !basicpy.BoolType -> i1
# CHECK: %[[IF0:.*]] = scf.if %[[CMP0_CAST]] -> (!basicpy.BoolType) {
# CHECK: %[[CMP1:.*]] = basicpy.binary_compare %[[Y]] "Eq" %[[Z]]
# CHECK: %[[CMP1_CAST:.*]] = basicpy.bool_cast %[[CMP1]] : !basicpy.BoolType -> i1
# CHECK: %[[IF1:.*]] = scf.if %[[CMP1_CAST]] {{.*}} {
# CHECK: %[[CMP2:.*]] = basicpy.binary_compare %[[Z]] "GtE" %[[OMEGA]]
# CHECK: scf.yield %[[CMP2]]
# CHECK: } else {
# CHECK: scf.yield %[[FALSE]]
# CHECK: }
# CHECK: scf.yield %[[IF1]]
# CHECK: } else {
# CHECK: scf.yield %[[FALSE]]
# CHECK: }
# CHECK: %[[RESULT:.*]] = basicpy.unknown_cast %[[IF0]]
# CHECK: return %[[RESULT]]
return x < y == z >= omega
# CHECK-LABEL: nested_short_circuit_expression
@import_global
def nested_short_circuit_expression():
x = 1
y = 2
z = 3
# Verify that the (z + 5) gets nested into the if.
# CHECK: scf.if {{.*}} {
# CHECK-NEXT: constant 6
# CHECK-NEXT: binary_expr "Add"
return x < y == (z + 6)

View File

@ -196,6 +196,11 @@ class ExpressionImporter(BaseNodeVisitor):
assert self.value, ("ExpressionImporter did not assign a value (%r)" %
(ast.dump(node),))
def sub_evaluate(self, sub_node):
sub_importer = ExpressionImporter(self.fctx)
sub_importer.visit(sub_node)
return sub_importer.value
def emit_constant(self, value):
ir_c = self.fctx.ir_c
ir_h = self.fctx.ir_h
@ -236,14 +241,46 @@ class ExpressionImporter(BaseNodeVisitor):
def visit_Compare(self, ast_node):
ir_h = self.fctx.ir_h
if len(ast_node.ops) != 1:
self.fctx.abort("unsupported short-circuit comparison")
left = ExpressionImporter(self.fctx)
left.visit(ast_node.left)
right = ExpressionImporter(self.fctx)
right.visit(ast_node.comparators[0])
self.value = ir_h.basicpy_binary_compare_op(
left.value, right.value, ast_node.ops[0].__class__.__name__).result
if len(ast_node.ops) == 1:
# Simplified single comparison emission.
left = self.sub_evaluate(ast_node.left)
right = self.sub_evaluate(ast_node.comparators[0])
self.value = ir_h.basicpy_binary_compare_op(
left, right, ast_node.ops[0].__class__.__name__).result
else:
# Short-circuit comparison.
false_value = ir_h.basicpy_bool_constant_op(False).result
def emit_next(left_value, comparisons):
operation, right_node = comparisons[0]
comparisons = comparisons[1:]
right_value = self.sub_evaluate(right_node)
compare_result = ir_h.basicpy_binary_compare_op(
left_value, right_value, operation.__class__.__name__).result
# Terminate by yielding the final compare result.
if not comparisons:
return compare_result
# Emit 'if' op and recurse. The if op takes an i1 (core dialect
# requirement) and returns a basicpy.BoolType. Since this is an 'and',
# all else clauses yield a false value.
compare_result_i1 = ir_h.basicpy_bool_cast_op(ir_h.i1_type,
compare_result).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
compare_result_i1, True)
orig_ip = ir_h.builder.insertion_point
# Build the else clause.
ir_h.builder.insertion_point = else_ip
ir_h.scf_yield_op([false_value])
# Build the then clause.
ir_h.builder.insertion_point = then_ip
nested_result = emit_next(right_value, comparisons)
ir_h.scf_yield_op([nested_result])
ir_h.builder.insertion_point = orig_ip
return if_op.result
self.value = emit_next(self.sub_evaluate(ast_node.left),
list(zip(ast_node.ops, ast_node.comparators)))
def visit_Name(self, ast_node):
if not isinstance(ast_node.ctx, ast.Load):

View File

@ -52,6 +52,9 @@ class DialectHelper(_BaseDialectHelper):
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
return self.op("basicpy.binary_expr", [result_type], [lhs, rhs], attrs)
def basicpy_bool_cast_op(self, result_type, value):
return self.op("basicpy.bool_cast", [result_type], [value])
def basicpy_bool_constant_op(self, value):
c = self.context
ival = 1 if value else 0

View File

@ -80,8 +80,10 @@ target_link_libraries(${extension_target}
NPCOMPBasicpyDialect
NPCOMPNumpyDialect
# Core dialects
MLIRSCF
# Upstream depends
LLVMSupport
MLIRAffineToStandard
MLIRAffineTransforms

View File

@ -9,6 +9,8 @@
#include "MlirIr.h"
#include "NpcompModule.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/BasicpyOps.h"
@ -21,8 +23,40 @@ public:
static void bind(py::module m) {
py::class_<ScfDialectHelper, PyDialectHelper>(m, "ScfDialectHelper")
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>());
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>())
.def("scf_yield_op",
[](ScfDialectHelper &self,
std::vector<PyValue> pyYields) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Value, 4> yields(pyYields.begin(),
pyYields.end());
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
return op.getOperation();
})
.def("scf_if_op",
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
PyValue cond, bool withElseRegion) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
pyResultTypes.end());
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
withElseRegion);
if (withElseRegion) {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint(),
op.getElseBodyBuilder().saveInsertionPoint());
} else {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint());
}
},
py::arg("cond"), py::arg("result_types"),
py::arg("with_else_region") = false);
}
};

View File

@ -31,6 +31,14 @@ struct PyContext;
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
MLIRContext *context);
//===----------------------------------------------------------------------===//
// Direct type bindings
//===----------------------------------------------------------------------===//
static void bindInsertPoint(py::module m) {
py::class_<OpBuilder::InsertPoint>(m, "InsertPoint");
}
//===----------------------------------------------------------------------===//
// Internal only template definitions
// Since it is only legal to use explicit instantiations of templates in
@ -377,6 +385,9 @@ void defineMlirIrModule(py::module m) {
PySymbolTable::bind(m);
PyType::bind(m);
PyValue::bind(m);
// Direct wrappings.
bindInsertPoint(m);
}
//===----------------------------------------------------------------------===//
@ -876,6 +887,13 @@ void PyOpBuilder::bind(py::module m) {
}
self.setCurrentLoc(Location(loc_attr));
})
.def_property("insertion_point",
[](PyOpBuilder &self) {
return self.getBuilder(true).saveInsertionPoint();
},
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
self.getBuilder(false).restoreInsertionPoint(ip);
})
.def("set_file_line_col",
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
unsigned column) {