compiler: split into transforms and py2llvm

This commit is contained in:
Sebastien Bourdeauducq 2014-09-05 22:18:31 +08:00
parent ef1f8787dc
commit a647e1104d
15 changed files with 61 additions and 61 deletions

View File

@ -1,10 +1,10 @@
from artiq.compiler.inline import inline from artiq.transforms.inline import inline
from artiq.compiler.lower_units import lower_units from artiq.transforms.lower_units import lower_units
from artiq.compiler.fold_constants import fold_constants from artiq.transforms.fold_constants import fold_constants
from artiq.compiler.unroll_loops import unroll_loops from artiq.transforms.unroll_loops import unroll_loops
from artiq.compiler.interleave import interleave from artiq.transforms.interleave import interleave
from artiq.compiler.lower_time import lower_time from artiq.transforms.lower_time import lower_time
from artiq.compiler.ir import get_runtime_binary from artiq.py2llvm import get_runtime_binary
class Core: class Core:

View File

@ -1,7 +1,7 @@
from llvm import core as lc from llvm import core as lc
from llvm import target as lt from llvm import target as lt
from artiq.compiler import ir_values from artiq.py2llvm import values
lt.initialize_all() lt.initialize_all()
@ -21,9 +21,9 @@ _chr_to_type = {
} }
_chr_to_value = { _chr_to_value = {
"n": lambda: ir_values.VNone(), "n": lambda: values.VNone(),
"i": lambda: ir_values.VInt(), "i": lambda: values.VInt(),
"I": lambda: ir_values.VInt(64) "I": lambda: values.VInt(64)
} }

View File

@ -1,19 +1,19 @@
from llvm import core as lc from llvm import core as lc
from llvm import passes as lp from llvm import passes as lp
from artiq.compiler import ir_infer_types, ir_ast_body, ir_values from artiq.py2llvm import infer_types, ast_body, values
def compile_function(module, env, funcdef): def _compile_function(module, env, funcdef):
function_type = lc.Type.function(lc.Type.void(), []) function_type = lc.Type.function(lc.Type.void(), [])
function = module.add_function(function_type, funcdef.name) function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry") bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb) builder = lc.Builder.new(bb)
ns = ir_infer_types.infer_types(env, funcdef) ns = infer_types.infer_types(env, funcdef)
for k, v in ns.items(): for k, v in ns.items():
v.alloca(builder, k) v.alloca(builder, k)
visitor = ir_ast_body.Visitor(env, ns, builder) visitor = ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body) visitor.visit_statements(funcdef.body)
builder.ret_void() builder.ret_void()
@ -21,9 +21,9 @@ def compile_function(module, env, funcdef):
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
module = lc.Module.new("main") module = lc.Module.new("main")
env.init_module(module) env.init_module(module)
ir_values.init_module(module) values.init_module(module)
compile_function(module, env, funcdef) _compile_function(module, env, funcdef)
pass_manager = lp.PassManager.new() pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG) pass_manager.add(lp.PASS_MEM2REG)

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.compiler import ir_values from artiq.py2llvm import values
class Visitor: class Visitor:
@ -29,9 +29,9 @@ class Visitor:
def _visit_expr_NameConstant(self, node): def _visit_expr_NameConstant(self, node):
v = node.value v = node.value
if v is None: if v is None:
r = ir_values.VNone() r = values.VNone()
elif isinstance(v, bool): elif isinstance(v, bool):
r = ir_values.VBool() r = values.VBool()
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
@ -42,9 +42,9 @@ class Visitor:
n = node.n n = node.n
if isinstance(n, int): if isinstance(n, int):
if abs(n) < 2**31: if abs(n) < 2**31:
r = ir_values.VInt() r = values.VInt()
else: else:
r = ir_values.VInt(64) r = values.VInt(64)
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
@ -53,28 +53,28 @@ class Visitor:
def _visit_expr_UnaryOp(self, node): def _visit_expr_UnaryOp(self, node):
ast_unops = { ast_unops = {
ast.Invert: ir_values.operators.inv, ast.Invert: values.operators.inv,
ast.Not: ir_values.operators.not_, ast.Not: values.operators.not_,
ast.UAdd: ir_values.operators.pos, ast.UAdd: values.operators.pos,
ast.USub: ir_values.operators.neg ast.USub: values.operators.neg
} }
return ast_unops[type(node.op)](self.visit_expression(node.operand), return ast_unops[type(node.op)](self.visit_expression(node.operand),
self.builder) self.builder)
def _visit_expr_BinOp(self, node): def _visit_expr_BinOp(self, node):
ast_binops = { ast_binops = {
ast.Add: ir_values.operators.add, ast.Add: values.operators.add,
ast.Sub: ir_values.operators.sub, ast.Sub: values.operators.sub,
ast.Mult: ir_values.operators.mul, ast.Mult: values.operators.mul,
ast.Div: ir_values.operators.truediv, ast.Div: values.operators.truediv,
ast.FloorDiv: ir_values.operators.floordiv, ast.FloorDiv: values.operators.floordiv,
ast.Mod: ir_values.operators.mod, ast.Mod: values.operators.mod,
ast.Pow: ir_values.operators.pow, ast.Pow: values.operators.pow,
ast.LShift: ir_values.operators.lshift, ast.LShift: values.operators.lshift,
ast.RShift: ir_values.operators.rshift, ast.RShift: values.operators.rshift,
ast.BitOr: ir_values.operators.or_, ast.BitOr: values.operators.or_,
ast.BitXor: ir_values.operators.xor, ast.BitXor: values.operators.xor,
ast.BitAnd: ir_values.operators.and_ ast.BitAnd: values.operators.and_
} }
return ast_binops[type(node.op)](self.visit_expression(node.left), return ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right), self.visit_expression(node.right),
@ -82,12 +82,12 @@ class Visitor:
def _visit_expr_Compare(self, node): def _visit_expr_Compare(self, node):
ast_cmps = { ast_cmps = {
ast.Eq: ir_values.operators.eq, ast.Eq: values.operators.eq,
ast.NotEq: ir_values.operators.ne, ast.NotEq: values.operators.ne,
ast.Lt: ir_values.operators.lt, ast.Lt: values.operators.lt,
ast.LtE: ir_values.operators.le, ast.LtE: values.operators.le,
ast.Gt: ir_values.operators.gt, ast.Gt: values.operators.gt,
ast.GtE: ir_values.operators.ge ast.GtE: values.operators.ge
} }
comparisons = [] comparisons = []
old_comparator = self.visit_expression(node.left) old_comparator = self.visit_expression(node.left)
@ -99,23 +99,23 @@ class Visitor:
old_comparator = comparator old_comparator = comparator
r = comparisons[0] r = comparisons[0]
for comparison in comparisons[1:]: for comparison in comparisons[1:]:
r = ir_values.operators.and_(r, comparison) r = values.operators.and_(r, comparison)
return r return r
def _visit_expr_Call(self, node): def _visit_expr_Call(self, node):
ast_unfuns = { ast_unfuns = {
"bool": ir_values.operators.bool, "bool": values.operators.bool,
"int": ir_values.operators.int, "int": values.operators.int,
"int64": ir_values.operators.int64, "int64": values.operators.int64,
"round": ir_values.operators.round, "round": values.operators.round,
"round64": ir_values.operators.round64, "round64": values.operators.round64,
} }
fn = node.func.id fn = node.func.id
if fn in ast_unfuns: if fn in ast_unfuns:
return ast_unfuns[fn](self.visit_expression(node.args[0]), return ast_unfuns[fn](self.visit_expression(node.args[0]),
self.builder) self.builder)
elif fn == "Fraction": elif fn == "Fraction":
r = ir_values.VFraction() r = values.VFraction()
if self.builder is not None: if self.builder is not None:
numerator = self.visit_expression(node.args[0]) numerator = self.visit_expression(node.args[0])
denominator = self.visit_expression(node.args[1]) denominator = self.visit_expression(node.args[1])
@ -164,7 +164,7 @@ class Visitor:
else_block = function.append_basic_block("i_else") else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge") merge_block = function.append_basic_block("i_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), condition = values.operators.bool(self.visit_expression(node.test),
self.builder) self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder), self.builder.cbranch(condition.get_ssa_value(self.builder),
then_block, else_block) then_block, else_block)
@ -185,14 +185,14 @@ class Visitor:
else_block = function.append_basic_block("w_else") else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge") merge_block = function.append_basic_block("w_merge")
condition = ir_values.operators.bool( condition = values.operators.bool(
self.visit_expression(node.test), self.builder) self.visit_expression(node.test), self.builder)
self.builder.cbranch( self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, else_block) condition.get_ssa_value(self.builder), body_block, else_block)
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
condition = ir_values.operators.bool( condition = values.operators.bool(
self.visit_expression(node.test), self.builder) self.visit_expression(node.test), self.builder)
self.builder.cbranch( self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block) condition.get_ssa_value(self.builder), body_block, merge_block)

View File

@ -2,7 +2,7 @@ import ast
from operator import itemgetter from operator import itemgetter
from copy import deepcopy from copy import deepcopy
from artiq.compiler.ir_ast_body import Visitor from artiq.py2llvm.ast_body import Visitor
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):

View File

View File

@ -1,7 +1,7 @@
import ast import ast
import operator import operator
from artiq.compiler.tools import * from artiq.transforms.tools import *
from artiq.language.core import int64, round64 from artiq.language.core import int64, round64

View File

@ -4,7 +4,7 @@ import inspect
import textwrap import textwrap
import ast import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.transforms.tools import eval_ast, value_to_ast
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units

View File

@ -1,7 +1,7 @@
import ast import ast
import types import types
from artiq.compiler.tools import * from artiq.transforms.tools import *
# -1 statement duration could not be pre-determined # -1 statement duration could not be pre-determined

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.compiler.tools import value_to_ast from artiq.transforms.tools import value_to_ast
from artiq.language.core import int64 from artiq.language.core import int64

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.compiler.tools import value_to_ast from artiq.transforms.tools import value_to_ast
from artiq.language import units from artiq.language import units

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.transforms.tools import eval_ast, value_to_ast
def _count_stmts(node): def _count_stmts(node):