mirror of https://github.com/llvm/torch-mlir
Add support for short-circuit comparisons with scf.if.
parent
a32219c3bb
commit
1ef3614682
|
@ -13,6 +13,12 @@ include "BasicpyDialect.td"
|
|||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BoolOrI1Type : AnyTypeOf<[Basicpy_BoolType, I1], "Python bool or i1">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary operation enum
|
||||
// The name matches the operation name in the python AST ("Add", "Mult", etc).
|
||||
|
@ -101,6 +107,18 @@ def Basicpy_BinaryExprOp : Basicpy_Op<"binary_expr", []> {
|
|||
let assemblyFormat = "$operation operands attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def Basicpy_BoolCastOp : Basicpy_Op<"bool_cast", [NoSideEffect]> {
|
||||
let summary = "Casts between BoolType and i1";
|
||||
let description = [{
|
||||
When interfacing with lower level dialect or progressively lowering
|
||||
the Python BoolType away, it is often necessary to cast between it and
|
||||
i1, which is used to represent bool-ness at lower levels.
|
||||
}];
|
||||
let arguments = (ins BoolOrI1Type:$operand);
|
||||
let results = (outs BoolOrI1Type:$result);
|
||||
let assemblyFormat = "$operand attr-dict `:` type(operands) `->` type(results)";
|
||||
}
|
||||
|
||||
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
|
||||
ConstantLike, NoSideEffect]> {
|
||||
let summary = "A boolean constant";
|
||||
|
|
|
@ -93,10 +93,44 @@ def binary_not_in_():
|
|||
# CHECK: {{.*}} = basicpy.binary_compare {{.*}} "NotIn" {{.*}} : i64, i64
|
||||
return x not in y
|
||||
|
||||
# @import_global
|
||||
# def short_circuit():
|
||||
# x = 1
|
||||
# y = 2
|
||||
# z = 3
|
||||
# return x < y < z
|
||||
@import_global
|
||||
def short_circuit():
|
||||
# CHECK: %[[X:.*]] = constant 1 : i64
|
||||
# CHECK: %[[Y:.*]] = constant 2 : i64
|
||||
# CHECK: %[[Z:.*]] = constant 3 : i64
|
||||
# CHECK: %[[OMEGA:.*]] = constant 5 : i64
|
||||
x = 1
|
||||
y = 2
|
||||
z = 3
|
||||
omega = 5
|
||||
# CHECK: %[[FALSE:.*]] = basicpy.bool_constant 0
|
||||
# CHECK: %[[CMP0:.*]] = basicpy.binary_compare %[[X]] "Lt" %[[Y]]
|
||||
# CHECK: %[[CMP0_CAST:.*]] = basicpy.bool_cast %[[CMP0]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[IF0:.*]] = scf.if %[[CMP0_CAST]] -> (!basicpy.BoolType) {
|
||||
# CHECK: %[[CMP1:.*]] = basicpy.binary_compare %[[Y]] "Eq" %[[Z]]
|
||||
# CHECK: %[[CMP1_CAST:.*]] = basicpy.bool_cast %[[CMP1]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[IF1:.*]] = scf.if %[[CMP1_CAST]] {{.*}} {
|
||||
# CHECK: %[[CMP2:.*]] = basicpy.binary_compare %[[Z]] "GtE" %[[OMEGA]]
|
||||
# CHECK: scf.yield %[[CMP2]]
|
||||
# CHECK: } else {
|
||||
# CHECK: scf.yield %[[FALSE]]
|
||||
# CHECK: }
|
||||
# CHECK: scf.yield %[[IF1]]
|
||||
# CHECK: } else {
|
||||
# CHECK: scf.yield %[[FALSE]]
|
||||
# CHECK: }
|
||||
# CHECK: %[[RESULT:.*]] = basicpy.unknown_cast %[[IF0]]
|
||||
# CHECK: return %[[RESULT]]
|
||||
return x < y == z >= omega
|
||||
|
||||
# CHECK-LABEL: nested_short_circuit_expression
|
||||
@import_global
|
||||
def nested_short_circuit_expression():
|
||||
x = 1
|
||||
y = 2
|
||||
z = 3
|
||||
# Verify that the (z + 5) gets nested into the if.
|
||||
# CHECK: scf.if {{.*}} {
|
||||
# CHECK-NEXT: constant 6
|
||||
# CHECK-NEXT: binary_expr "Add"
|
||||
return x < y == (z + 6)
|
||||
|
|
|
@ -196,6 +196,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
assert self.value, ("ExpressionImporter did not assign a value (%r)" %
|
||||
(ast.dump(node),))
|
||||
|
||||
def sub_evaluate(self, sub_node):
|
||||
sub_importer = ExpressionImporter(self.fctx)
|
||||
sub_importer.visit(sub_node)
|
||||
return sub_importer.value
|
||||
|
||||
def emit_constant(self, value):
|
||||
ir_c = self.fctx.ir_c
|
||||
ir_h = self.fctx.ir_h
|
||||
|
@ -236,14 +241,46 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
|
||||
def visit_Compare(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
if len(ast_node.ops) != 1:
|
||||
self.fctx.abort("unsupported short-circuit comparison")
|
||||
left = ExpressionImporter(self.fctx)
|
||||
left.visit(ast_node.left)
|
||||
right = ExpressionImporter(self.fctx)
|
||||
right.visit(ast_node.comparators[0])
|
||||
self.value = ir_h.basicpy_binary_compare_op(
|
||||
left.value, right.value, ast_node.ops[0].__class__.__name__).result
|
||||
if len(ast_node.ops) == 1:
|
||||
# Simplified single comparison emission.
|
||||
left = self.sub_evaluate(ast_node.left)
|
||||
right = self.sub_evaluate(ast_node.comparators[0])
|
||||
self.value = ir_h.basicpy_binary_compare_op(
|
||||
left, right, ast_node.ops[0].__class__.__name__).result
|
||||
else:
|
||||
# Short-circuit comparison.
|
||||
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||
|
||||
def emit_next(left_value, comparisons):
|
||||
operation, right_node = comparisons[0]
|
||||
comparisons = comparisons[1:]
|
||||
right_value = self.sub_evaluate(right_node)
|
||||
compare_result = ir_h.basicpy_binary_compare_op(
|
||||
left_value, right_value, operation.__class__.__name__).result
|
||||
# Terminate by yielding the final compare result.
|
||||
if not comparisons:
|
||||
return compare_result
|
||||
|
||||
# Emit 'if' op and recurse. The if op takes an i1 (core dialect
|
||||
# requirement) and returns a basicpy.BoolType. Since this is an 'and',
|
||||
# all else clauses yield a false value.
|
||||
compare_result_i1 = ir_h.basicpy_bool_cast_op(ir_h.i1_type,
|
||||
compare_result).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
|
||||
compare_result_i1, True)
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
# Build the else clause.
|
||||
ir_h.builder.insertion_point = else_ip
|
||||
ir_h.scf_yield_op([false_value])
|
||||
# Build the then clause.
|
||||
ir_h.builder.insertion_point = then_ip
|
||||
nested_result = emit_next(right_value, comparisons)
|
||||
ir_h.scf_yield_op([nested_result])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
return if_op.result
|
||||
|
||||
self.value = emit_next(self.sub_evaluate(ast_node.left),
|
||||
list(zip(ast_node.ops, ast_node.comparators)))
|
||||
|
||||
def visit_Name(self, ast_node):
|
||||
if not isinstance(ast_node.ctx, ast.Load):
|
||||
|
|
|
@ -52,6 +52,9 @@ class DialectHelper(_BaseDialectHelper):
|
|||
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
|
||||
return self.op("basicpy.binary_expr", [result_type], [lhs, rhs], attrs)
|
||||
|
||||
def basicpy_bool_cast_op(self, result_type, value):
|
||||
return self.op("basicpy.bool_cast", [result_type], [value])
|
||||
|
||||
def basicpy_bool_constant_op(self, value):
|
||||
c = self.context
|
||||
ival = 1 if value else 0
|
||||
|
|
|
@ -80,8 +80,10 @@ target_link_libraries(${extension_target}
|
|||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
|
||||
# Core dialects
|
||||
MLIRSCF
|
||||
|
||||
# Upstream depends
|
||||
|
||||
LLVMSupport
|
||||
MLIRAffineToStandard
|
||||
MLIRAffineTransforms
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
#include "MlirIr.h"
|
||||
#include "NpcompModule.h"
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOps.h"
|
||||
|
||||
|
@ -21,8 +23,40 @@ public:
|
|||
|
||||
static void bind(py::module m) {
|
||||
py::class_<ScfDialectHelper, PyDialectHelper>(m, "ScfDialectHelper")
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>());
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>())
|
||||
.def("scf_yield_op",
|
||||
[](ScfDialectHelper &self,
|
||||
std::vector<PyValue> pyYields) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Value, 4> yields(pyYields.begin(),
|
||||
pyYields.end());
|
||||
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def("scf_if_op",
|
||||
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
||||
PyValue cond, bool withElseRegion) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
||||
withElseRegion);
|
||||
if (withElseRegion) {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint(),
|
||||
op.getElseBodyBuilder().saveInsertionPoint());
|
||||
} else {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint());
|
||||
}
|
||||
},
|
||||
py::arg("cond"), py::arg("result_types"),
|
||||
py::arg("with_else_region") = false);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,14 @@ struct PyContext;
|
|||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||
MLIRContext *context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Direct type bindings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void bindInsertPoint(py::module m) {
|
||||
py::class_<OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Internal only template definitions
|
||||
// Since it is only legal to use explicit instantiations of templates in
|
||||
|
@ -377,6 +385,9 @@ void defineMlirIrModule(py::module m) {
|
|||
PySymbolTable::bind(m);
|
||||
PyType::bind(m);
|
||||
PyValue::bind(m);
|
||||
|
||||
// Direct wrappings.
|
||||
bindInsertPoint(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -876,6 +887,13 @@ void PyOpBuilder::bind(py::module m) {
|
|||
}
|
||||
self.setCurrentLoc(Location(loc_attr));
|
||||
})
|
||||
.def_property("insertion_point",
|
||||
[](PyOpBuilder &self) {
|
||||
return self.getBuilder(true).saveInsertionPoint();
|
||||
},
|
||||
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
||||
self.getBuilder(false).restoreInsertionPoint(ip);
|
||||
})
|
||||
.def("set_file_line_col",
|
||||
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) {
|
||||
|
|
Loading…
Reference in New Issue