mirror of https://github.com/llvm/torch-mlir
Add IfExp emission.
parent
e18e8e0a96
commit
22cbe044c2
|
@ -75,3 +75,20 @@ def logical_not():
|
||||||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||||
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||||
return not x
|
return not x
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @conditional
|
||||||
|
@import_global
|
||||||
|
def conditional():
|
||||||
|
# CHECK: %[[X:.*]] = constant 1
|
||||||
|
x = 1
|
||||||
|
# CHECK: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||||
|
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
|
||||||
|
# CHECK: %[[TWO:.*]] = constant 2 : i64
|
||||||
|
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]
|
||||||
|
# CHECK: scf.yield %[[TWO_CAST]]
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[THREE:.*]] = constant 3 : i64
|
||||||
|
# CHECK: %[[THREE_CAST:.*]] = basicpy.unknown_cast %[[THREE]]
|
||||||
|
# CHECK: scf.yield %[[THREE_CAST]]
|
||||||
|
# CHECK: }
|
||||||
|
return 2 if x else 3
|
||||||
|
|
|
@ -309,6 +309,32 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
self.value = emit_next(self.sub_evaluate(ast_node.left),
|
self.value = emit_next(self.sub_evaluate(ast_node.left),
|
||||||
list(zip(ast_node.ops, ast_node.comparators)))
|
list(zip(ast_node.ops, ast_node.comparators)))
|
||||||
|
|
||||||
|
def visit_IfExp(self, ast_node):
|
||||||
|
ir_h = self.fctx.ir_h
|
||||||
|
test_result = ir_h.basicpy_to_boolean_op(self.sub_evaluate(
|
||||||
|
ast_node.test)).result
|
||||||
|
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||||
|
test_result, True)
|
||||||
|
|
||||||
|
orig_ip = ir_h.builder.insertion_point
|
||||||
|
# Build the then clause
|
||||||
|
ir_h.builder.insertion_point = then_ip
|
||||||
|
then_result = self.sub_evaluate(ast_node.body)
|
||||||
|
ir_h.scf_yield_op([
|
||||||
|
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||||
|
then_result).result
|
||||||
|
])
|
||||||
|
# Build the then clause.
|
||||||
|
ir_h.builder.insertion_point = else_ip
|
||||||
|
orelse_result = self.sub_evaluate(ast_node.orelse)
|
||||||
|
ir_h.scf_yield_op([
|
||||||
|
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||||
|
orelse_result).result
|
||||||
|
])
|
||||||
|
ir_h.builder.insertion_point = orig_ip
|
||||||
|
|
||||||
|
self.value = if_op.result
|
||||||
|
|
||||||
def visit_Name(self, ast_node):
|
def visit_Name(self, ast_node):
|
||||||
if not isinstance(ast_node.ctx, ast.Load):
|
if not isinstance(ast_node.ctx, ast.Load):
|
||||||
self.fctx.abort("Unsupported expression name context type %s" %
|
self.fctx.abort("Unsupported expression name context type %s" %
|
||||||
|
|
Loading…
Reference in New Issue