diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index e80a4af63..d96504510 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -22,14 +22,14 @@ class _Namespace: def load(self, builder, name): return builder.load(self.bindings[name]) -def _emit_expr(builder, ns, node): +def _emit_expr(env, builder, ns, node): if isinstance(node, ast.Name): return ns.load(builder, node.id) elif isinstance(node, ast.Num): return lc.Constant.int(lc.Type.int(), node.n) elif isinstance(node, ast.BinOp): - left = _emit_expr(builder, ns, node.left) - right = _emit_expr(builder, ns, node.right) + left = _emit_expr(env, builder, ns, node.left) + right = _emit_expr(env, builder, ns, node.right) mapping = { ast.Add: builder.add, ast.Sub: builder.sub, @@ -41,9 +41,9 @@ def _emit_expr(builder, ns, node): return bf(left, right) elif isinstance(node, ast.Compare): comparisons = [] - old_comparator = _emit_expr(builder, ns, node.left) + old_comparator = _emit_expr(env, builder, ns, node.left) for op, comparator_a in zip(node.ops, node.comparators): - comparator = _emit_expr(builder, ns, comparator_a) + comparator = _emit_expr(env, builder, ns, comparator_a) mapping = { ast.Eq: lc.ICMP_EQ, ast.NotEq: lc.ICMP_NE, @@ -59,16 +59,22 @@ def _emit_expr(builder, ns, node): for comparison in comparisons[1:]: r = builder.band(r, comparison) return r + elif isinstance(node, ast.Call): + if node.func.id == "syscall": + return env.emit_syscall(builder, node.args[0].s, + [_emit_expr(env, builder, ns, expr) for expr in node.args[1:]]) + else: + raise NotImplementedError else: raise NotImplementedError def _emit_statements(env, builder, ns, stmts): for stmt in stmts: if isinstance(stmt, ast.Return): - val = _emit_expr(builder, ns, stmt.value) + val = _emit_expr(env, builder, ns, stmt.value) builder.ret(val) elif isinstance(stmt, ast.Assign): - val = _emit_expr(builder, ns, stmt.value) + val = _emit_expr(env, builder, ns, stmt.value) for target in stmt.targets: if isinstance(target, ast.Name): ns.store(builder, val, target.id) @@ -80,7 +86,7 @@ def _emit_statements(env, builder, ns, stmts): else_block = function.append_basic_block("i_else") merge_block = function.append_basic_block("i_merge") - condition = _emit_expr(builder, ns, stmt.test) + condition = _emit_expr(env, builder, ns, stmt.test) builder.cbranch(condition, then_block, else_block) builder.position_at_end(then_block) @@ -98,12 +104,12 @@ def _emit_statements(env, builder, ns, stmts): else_block = function.append_basic_block("w_else") merge_block = function.append_basic_block("w_merge") - condition = _emit_expr(builder, ns, stmt.test) + condition = _emit_expr(env, builder, ns, stmt.test) builder.cbranch(condition, body_block, else_block) builder.position_at_end(body_block) _emit_statements(env, builder, ns, stmt.body) - condition = _emit_expr(builder, ns, stmt.test) + condition = _emit_expr(env, builder, ns, stmt.test) builder.cbranch(condition, body_block, merge_block) builder.position_at_end(else_block) @@ -111,13 +117,8 @@ def _emit_statements(env, builder, ns, stmts): builder.branch(merge_block) builder.position_at_end(merge_block) - elif isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call): - call = stmt.value - if call.func.id == "syscall": - env.emit_syscall(builder, call.args[0].s, - [_emit_expr(builder, ns, expr) for expr in call.args[1:]]) - else: - raise NotImplementedError + elif isinstance(stmt, ast.Expr): + _emit_expr(env, builder, ns, stmt.value) else: raise NotImplementedError