compiler/ir: support function calls in expressions

This commit is contained in:
Sebastien Bourdeauducq 2014-07-06 21:06:01 +02:00
parent 2aa63ba57d
commit e0ac1193c6
1 changed files with 18 additions and 17 deletions

View File

@ -22,14 +22,14 @@ class _Namespace:
def load(self, builder, name):
return builder.load(self.bindings[name])
def _emit_expr(builder, ns, node):
def _emit_expr(env, builder, ns, node):
if isinstance(node, ast.Name):
return ns.load(builder, node.id)
elif isinstance(node, ast.Num):
return lc.Constant.int(lc.Type.int(), node.n)
elif isinstance(node, ast.BinOp):
left = _emit_expr(builder, ns, node.left)
right = _emit_expr(builder, ns, node.right)
left = _emit_expr(env, builder, ns, node.left)
right = _emit_expr(env, builder, ns, node.right)
mapping = {
ast.Add: builder.add,
ast.Sub: builder.sub,
@ -41,9 +41,9 @@ def _emit_expr(builder, ns, node):
return bf(left, right)
elif isinstance(node, ast.Compare):
comparisons = []
old_comparator = _emit_expr(builder, ns, node.left)
old_comparator = _emit_expr(env, builder, ns, node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = _emit_expr(builder, ns, comparator_a)
comparator = _emit_expr(env, builder, ns, comparator_a)
mapping = {
ast.Eq: lc.ICMP_EQ,
ast.NotEq: lc.ICMP_NE,
@ -59,16 +59,22 @@ def _emit_expr(builder, ns, node):
for comparison in comparisons[1:]:
r = builder.band(r, comparison)
return r
elif isinstance(node, ast.Call):
if node.func.id == "syscall":
return env.emit_syscall(builder, node.args[0].s,
[_emit_expr(env, builder, ns, expr) for expr in node.args[1:]])
else:
raise NotImplementedError
else:
raise NotImplementedError
def _emit_statements(env, builder, ns, stmts):
for stmt in stmts:
if isinstance(stmt, ast.Return):
val = _emit_expr(builder, ns, stmt.value)
val = _emit_expr(env, builder, ns, stmt.value)
builder.ret(val)
elif isinstance(stmt, ast.Assign):
val = _emit_expr(builder, ns, stmt.value)
val = _emit_expr(env, builder, ns, stmt.value)
for target in stmt.targets:
if isinstance(target, ast.Name):
ns.store(builder, val, target.id)
@ -80,7 +86,7 @@ def _emit_statements(env, builder, ns, stmts):
else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge")
condition = _emit_expr(builder, ns, stmt.test)
condition = _emit_expr(env, builder, ns, stmt.test)
builder.cbranch(condition, then_block, else_block)
builder.position_at_end(then_block)
@ -98,12 +104,12 @@ def _emit_statements(env, builder, ns, stmts):
else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge")
condition = _emit_expr(builder, ns, stmt.test)
condition = _emit_expr(env, builder, ns, stmt.test)
builder.cbranch(condition, body_block, else_block)
builder.position_at_end(body_block)
_emit_statements(env, builder, ns, stmt.body)
condition = _emit_expr(builder, ns, stmt.test)
condition = _emit_expr(env, builder, ns, stmt.test)
builder.cbranch(condition, body_block, merge_block)
builder.position_at_end(else_block)
@ -111,13 +117,8 @@ def _emit_statements(env, builder, ns, stmts):
builder.branch(merge_block)
builder.position_at_end(merge_block)
elif isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call):
call = stmt.value
if call.func.id == "syscall":
env.emit_syscall(builder, call.args[0].s,
[_emit_expr(builder, ns, expr) for expr in call.args[1:]])
else:
raise NotImplementedError
elif isinstance(stmt, ast.Expr):
_emit_expr(env, builder, ns, stmt.value)
else:
raise NotImplementedError