mirror of https://github.com/llvm/torch-mlir
Add IfExp emission.
parent
e18e8e0a96
commit
22cbe044c2
|
@ -75,3 +75,20 @@ def logical_not():
|
|||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||
return not x
|
||||
|
||||
# CHECK-LABEL: func @conditional
|
||||
@import_global
|
||||
def conditional():
|
||||
# CHECK: %[[X:.*]] = constant 1
|
||||
x = 1
|
||||
# CHECK: %[[CONDITION:.*]] = basicpy.to_boolean %[[X]]
|
||||
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
|
||||
# CHECK: %[[TWO:.*]] = constant 2 : i64
|
||||
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]
|
||||
# CHECK: scf.yield %[[TWO_CAST]]
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[THREE:.*]] = constant 3 : i64
|
||||
# CHECK: %[[THREE_CAST:.*]] = basicpy.unknown_cast %[[THREE]]
|
||||
# CHECK: scf.yield %[[THREE_CAST]]
|
||||
# CHECK: }
|
||||
return 2 if x else 3
|
||||
|
|
|
@ -309,6 +309,32 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
self.value = emit_next(self.sub_evaluate(ast_node.left),
|
||||
list(zip(ast_node.ops, ast_node.comparators)))
|
||||
|
||||
def visit_IfExp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
test_result = ir_h.basicpy_to_boolean_op(self.sub_evaluate(
|
||||
ast_node.test)).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||
test_result, True)
|
||||
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
# Build the then clause
|
||||
ir_h.builder.insertion_point = then_ip
|
||||
then_result = self.sub_evaluate(ast_node.body)
|
||||
ir_h.scf_yield_op([
|
||||
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
then_result).result
|
||||
])
|
||||
# Build the then clause.
|
||||
ir_h.builder.insertion_point = else_ip
|
||||
orelse_result = self.sub_evaluate(ast_node.orelse)
|
||||
ir_h.scf_yield_op([
|
||||
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
orelse_result).result
|
||||
])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
|
||||
self.value = if_op.result
|
||||
|
||||
def visit_Name(self, ast_node):
|
||||
if not isinstance(ast_node.ctx, ast.Load):
|
||||
self.fctx.abort("Unsupported expression name context type %s" %
|
||||
|
|
Loading…
Reference in New Issue