mirror of https://github.com/llvm/torch-mlir
Add boolean/logical operations (and, or, not).
* Adds a new to_boolean op to evaluate a value as a truthy i1 * Uses cascading scf.if ops to properly evaluate and/or sequences (short-circuit and original value returning) * Adds a helper to construct select ops and uses it to implement 'not'pull/1/head
parent
b0a80e04f1
commit
e18e8e0a96
|
@ -240,6 +240,19 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
|
||||||
let assemblyFormat = "attr-dict `:` type($result)";
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Basicpy_ToBooleanOp : Basicpy_Op<"to_boolean", [NoSideEffect]> {
|
||||||
|
let summary = "Evaluates an input to an i1 boolean value";
|
||||||
|
let description = [{
|
||||||
|
Applies the rules for interpreting a type as a boolean, returning an i1
|
||||||
|
indicating the truthiness of the operand. Since the output of this op
|
||||||
|
is intended to drive lower-level control flow, the i1 type is used (not
|
||||||
|
the user level BoolType).
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyType:$operand);
|
||||||
|
let results = (outs I1:$result);
|
||||||
|
let assemblyFormat = "$operand attr-dict `:` type($operand)";
|
||||||
|
}
|
||||||
|
|
||||||
def Basicpy_UnknownCastOp : Basicpy_Op<"unknown_cast", [NoSideEffect]> {
|
def Basicpy_UnknownCastOp : Basicpy_Op<"unknown_cast", [NoSideEffect]> {
|
||||||
let summary = "Casts to and from the UnknownType";
|
let summary = "Casts to and from the UnknownType";
|
||||||
let arguments = (ins AnyType:$input);
|
let arguments = (ins AnyType:$input);
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
from npcomp.compiler.frontend import *
|
||||||
|
|
||||||
|
|
||||||
|
def import_global(f):
|
||||||
|
fe = ImportFrontend()
|
||||||
|
fe.import_global_function(f)
|
||||||
|
print("// -----")
|
||||||
|
print(fe.ir_module.to_asm())
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @logical_and
|
||||||
|
@import_global
|
||||||
|
def logical_and():
|
||||||
|
# CHECK: %[[X:.*]] = constant 1
|
||||||
|
# CHECK: %[[Y:.*]] = constant 0
|
||||||
|
# CHECK: %[[Z:.*]] = constant 2
|
||||||
|
x = 1
|
||||||
|
y = 0
|
||||||
|
z = 2
|
||||||
|
# CHECK: %[[XBOOL:.*]] = basicpy.to_boolean %[[X]]
|
||||||
|
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||||
|
# CHECK: %[[YBOOL:.*]] = basicpy.to_boolean %[[Y]]
|
||||||
|
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||||
|
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
|
||||||
|
# CHECK: scf.yield %[[ZCAST]]
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
|
||||||
|
# CHECK: scf.yield %[[YCAST]]
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: %[[IF1CAST:.*]] = basicpy.unknown_cast %[[IF1]]
|
||||||
|
# CHECK: scf.yield %[[IF1CAST]]
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
|
||||||
|
# CHECK: scf.yield %[[XCAST]]
|
||||||
|
# CHECK: }
|
||||||
|
return x and y and z
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @logical_or
|
||||||
|
@import_global
|
||||||
|
def logical_or():
|
||||||
|
# CHECK: %[[X:.*]] = constant 0
|
||||||
|
# CHECK: %[[Y:.*]] = constant 1
|
||||||
|
# CHECK: %[[Z:.*]] = constant 2
|
||||||
|
# CHECK: %[[XBOOL:.*]] = basicpy.to_boolean %[[X]]
|
||||||
|
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||||
|
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
|
||||||
|
# CHECK: scf.yield %[[XCAST]]
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[YBOOL:.*]] = basicpy.to_boolean %[[Y]]
|
||||||
|
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||||
|
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
|
||||||
|
# CHECK: scf.yield %[[YCAST]]
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
|
||||||
|
# CHECK: scf.yield %[[ZCAST]]
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: %[[IF1CAST:.*]] = basicpy.unknown_cast %[[IF1]]
|
||||||
|
# CHECK: scf.yield %[[IF1CAST]]
|
||||||
|
# CHECK: }
|
||||||
|
x = 0
|
||||||
|
y = 1
|
||||||
|
z = 2
|
||||||
|
return x or y or z
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @logical_not
|
||||||
|
@import_global
|
||||||
|
def logical_not():
|
||||||
|
# CHECK: %[[X:.*]] = constant 1
|
||||||
|
x = 1
|
||||||
|
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant 1
|
||||||
|
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant 0
|
||||||
|
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||||
|
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||||
|
return not x
|
|
@ -237,6 +237,41 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
ir_h.basicpy_UnknownType, left, right,
|
ir_h.basicpy_UnknownType, left, right,
|
||||||
ast_node.op.__class__.__name__).result
|
ast_node.op.__class__.__name__).result
|
||||||
|
|
||||||
|
def visit_BoolOp(self, ast_node):
|
||||||
|
ir_h = self.fctx.ir_h
|
||||||
|
if isinstance(ast_node.op, ast.And):
|
||||||
|
return_first_true = False
|
||||||
|
elif isinstance(ast_node.op, ast.Or):
|
||||||
|
return_first_true = True
|
||||||
|
else:
|
||||||
|
self.fctx.abort("unknown bool op %r" % (ast.dump(ast_node.op)))
|
||||||
|
|
||||||
|
def emit_next(next_nodes):
|
||||||
|
next_node = next_nodes[0]
|
||||||
|
next_nodes = next_nodes[1:]
|
||||||
|
next_value = self.sub_evaluate(next_node)
|
||||||
|
if not next_nodes:
|
||||||
|
return next_value
|
||||||
|
condition_value = ir_h.basicpy_to_boolean_op(next_value).result
|
||||||
|
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||||
|
condition_value, True)
|
||||||
|
orig_ip = ir_h.builder.insertion_point
|
||||||
|
# Short-circuit return case.
|
||||||
|
ir_h.builder.insertion_point = then_ip if return_first_true else else_ip
|
||||||
|
next_value_casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||||
|
next_value).result
|
||||||
|
ir_h.scf_yield_op([next_value_casted])
|
||||||
|
# Nested evaluate next case.
|
||||||
|
ir_h.builder.insertion_point = else_ip if return_first_true else then_ip
|
||||||
|
nested_value = emit_next(next_nodes)
|
||||||
|
nested_value_casted = next_value_casted = ir_h.basicpy_unknown_cast_op(
|
||||||
|
ir_h.basicpy_UnknownType, nested_value).result
|
||||||
|
ir_h.scf_yield_op([nested_value_casted])
|
||||||
|
ir_h.builder.insertion_point = orig_ip
|
||||||
|
return if_op.result
|
||||||
|
|
||||||
|
self.value = emit_next(ast_node.values)
|
||||||
|
|
||||||
def visit_Compare(self, ast_node):
|
def visit_Compare(self, ast_node):
|
||||||
# Short-circuit comparison (degenerates to binary comparison when just
|
# Short-circuit comparison (degenerates to binary comparison when just
|
||||||
# two operands).
|
# two operands).
|
||||||
|
@ -284,6 +319,20 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
self.fctx.abort("Local variable '%s' has not been assigned" % ast_node.id)
|
self.fctx.abort("Local variable '%s' has not been assigned" % ast_node.id)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
def visit_UnaryOp(self, ast_node):
|
||||||
|
ir_h = self.fctx.ir_h
|
||||||
|
op = ast_node.op
|
||||||
|
operand_value = self.sub_evaluate(ast_node.operand)
|
||||||
|
if isinstance(op, ast.Not):
|
||||||
|
# Special handling for logical-not.
|
||||||
|
condition_value = ir_h.basicpy_to_boolean_op(operand_value).result
|
||||||
|
true_value = ir_h.basicpy_bool_constant_op(True).result
|
||||||
|
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||||
|
self.value = ir_h.select_op(condition_value, false_value,
|
||||||
|
true_value).result
|
||||||
|
else:
|
||||||
|
self.fctx.abort("Unknown unary op %r", (ast.dump(op)))
|
||||||
|
|
||||||
if sys.version_info < (3, 8, 0):
|
if sys.version_info < (3, 8, 0):
|
||||||
# <3.8 breaks these out into separate AST classes.
|
# <3.8 breaks these out into separate AST classes.
|
||||||
def visit_Num(self, ast_node):
|
def visit_Num(self, ast_node):
|
||||||
|
|
|
@ -90,6 +90,9 @@ class DialectHelper(_BaseDialectHelper):
|
||||||
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
||||||
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
||||||
|
|
||||||
|
def basicpy_to_boolean_op(self, value):
|
||||||
|
return self.op("basicpy.to_boolean", [self.i1_type], [value])
|
||||||
|
|
||||||
def basicpy_unknown_cast_op(self, result_type, operand):
|
def basicpy_unknown_cast_op(self, result_type, operand):
|
||||||
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ public:
|
||||||
op.getThenBodyBuilder().saveInsertionPoint());
|
op.getThenBodyBuilder().saveInsertionPoint());
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
py::arg("cond"), py::arg("result_types"),
|
py::arg("result_types"), py::arg("cond"),
|
||||||
py::arg("with_else_region") = false);
|
py::arg("with_else_region") = false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -248,6 +248,15 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
R"(Creates a new `func` op, optionally creating an entry block.
|
R"(Creates a new `func` op, optionally creating an entry block.
|
||||||
If an entry block is created, the builder will be positioned
|
If an entry block is created, the builder will be positioned
|
||||||
to its start.)")
|
to its start.)")
|
||||||
|
.def("select_op",
|
||||||
|
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
||||||
|
PyValue falseValue) -> PyOperationRef {
|
||||||
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
|
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||||
|
return PyOperationRef(opBuilder.create<SelectOp>(
|
||||||
|
loc, conditionValue, trueValue, falseValue));
|
||||||
|
},
|
||||||
|
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
||||||
.def("return_op",
|
.def("return_op",
|
||||||
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
|
|
Loading…
Reference in New Issue