compiler: new architecture for type inference and LLVM code emission

This commit is contained in:
Sebastien Bourdeauducq 2014-08-16 23:20:16 +08:00
parent 9189ad5fab
commit 3e4cbba018
4 changed files with 321 additions and 179 deletions

View File

@ -1,179 +0,0 @@
from collections import namedtuple
import ast
from artiq.language import units
TBool = namedtuple("TBool", "")
TFloat = namedtuple("TFloat", "")
TInt = namedtuple("TInt", "nbits")
TFraction = namedtuple("TFraction", "")
class TypeAnnotation:
def __init__(self, t, unit=None):
self.t = t
self.unit = unit
def __repr__(self):
r = "TypeAnnotation("+str(self.t)
if self.unit is not None:
r += " <unit:"+str(self.unit.name)+">"
r += ")"
return r
def __eq__(self, other):
return self.t == other.t and self.unit == other.unit
def promote(self, ta):
if ta.unit != self.unit:
raise units.DimensionError
if isinstance(self.t, TBool):
if not isinstance(ta.t, TBool):
raise TypeError
elif isinstance(self.t, TFloat):
if not isinstance(ta.t, TFloat):
raise TypeError
elif isinstance(self.t, TInt):
if isinstance(ta.t, TInt):
self.t = TInt(max(self.t.nbits, ta.t.nbits))
else:
raise TypeError
elif isinstance(self.t, TFraction):
if not isinstance(ta.t, TFraction):
raise TypeError
else:
raise TypeError
def _get_addsub_type(l, r):
if l.unit != r.unit:
raise units.DimensionError
if isinstance(l.t, TFloat):
if isinstance(r.t, (TFloat, TInt, TFraction)):
return l
else:
raise TypeError
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit)
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
return r
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
return r
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
return l
raise TypeError
def _get_mul_type(l, r):
unit = l.unit
if r.unit is not None:
if unit is None:
unit = r.unit
else:
raise NotImplementedError
if isinstance(l.t, TFloat):
if isinstance(r.t, (TFloat, TInt, TFraction)):
return TypeAnnotation(TFloat(), unit)
else:
raise TypeError
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
return TypeAnnotation(r.t, unit)
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
return TypeAnnotation(TFloat(), unit)
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
return TypeAnnotation(TFraction(), unit)
raise TypeError
def _get_div_unit(l, r):
if l.unit is not None and r.unit is None:
return l.unit
elif l.unit == r.unit:
return None
else:
raise NotImplementedError
def _get_truediv_type(l, r):
unit = _get_div_unit(l, r)
if isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
return TypeAnnotation(TFraction(), unit)
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
return TypeAnnotation(TFraction(), unit)
else:
return TypeAnnotation(TFloat(), unit)
def _get_floordiv_type(l, r):
unit = _get_div_unit(l, r)
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
elif isinstance(l.t, (TInt, TFloat)) and isinstance(r.t, TFloat):
return TypeAnnotation(TFloat(), unit)
elif isinstance(l.t, TFloat) and isinstance(r.t, (TInt, TFloat)):
return TypeAnnotation(TFloat(), unit)
elif (isinstance(l.t, TFloat) and isinstance(r.t, TFraction)) or (isinstance(l.t, TFraction) and isinstance(r.t, TFloat)):
return TypeAnnotation(TInt(64), unit)
elif isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
return TypeAnnotation(TFraction(), unit)
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
return TypeAnnotation(TFraction(), unit)
else:
raise NotImplementedError
def _get_call_type(sym_to_type, node):
fn = node.func.id
if fn == "bool":
return TypeAnnotation(TBool())
elif fn == "float":
return TypeAnnotation(TFloat())
elif fn == "int" or fn == "round":
return TypeAnnotation(TInt(32))
elif fn == "int64" or fn == "round64":
return TypeAnnotation(TInt(64))
elif fn == "Fraction":
return TypeAnnotation(TFraction())
elif fn == "Quantity":
ta = _get_expr_type(sym_to_type, node.args[0])
ta.unit = getattr(units, node.args[1].id)
return ta
else:
raise NotImplementedError
def _get_expr_type(sym_to_type, node):
if isinstance(node, ast.NameConstant):
if isinstance(node.value, bool):
return TypeAnnotation(TBool())
else:
raise TypeError
elif isinstance(node, ast.Num):
if isinstance(node.n, int):
nbits = 32 if abs(node.n) < 2**31 else 64
return TypeAnnotation(TInt(nbits))
elif isinstance(node.n, float):
return TypeAnnotation(TFloat())
else:
raise TypeError
elif isinstance(node, ast.Name):
return sym_to_type[node.id]
elif isinstance(node, ast.UnaryOp):
return _get_expr_type(sym_to_type, node.operand)
elif isinstance(node, ast.Compare):
return TypeAnnotation(TBool())
elif isinstance(node, ast.BinOp):
l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right)
if isinstance(node.op, (ast.Add, ast.Sub)):
return _get_addsub_type(l, r)
elif isinstance(node.op, ast.Mul):
return _get_mul_type(l, r)
elif isinstance(node.op, ast.Div):
return _get_truediv_type(l, r)
elif isinstance(node.op, ast.FloorDiv):
return _get_floordiv_type(l, r)
else:
raise NotImplementedError
elif isinstance(node, ast.Call):
return _get_call_type(sym_to_type, node)
else:
raise NotImplementedError
if __name__ == "__main__":
import sys
testexpr = ast.parse(sys.argv[1], mode="eval")
print(_get_expr_type(dict(), testexpr.body))

View File

@ -0,0 +1,92 @@
import ast
from artiq.compiler import ir_values
_ast_unops = {
ast.Invert: ir_values.operators.inv,
ast.Not: ir_values.operators.not_,
ast.UAdd: ir_values.operators.pos,
ast.USub: ir_values.operators.neg
}
_ast_binops = {
ast.Add: ir_values.operators.add,
ast.Sub: ir_values.operators.sub,
ast.Mult: ir_values.operators.mul,
ast.Div: ir_values.operators.truediv,
ast.FloorDiv: ir_values.operators.floordiv,
ast.Mod: ir_values.operators.mod,
ast.Pow: ir_values.operators.pow,
ast.LShift: ir_values.operators.lshift,
ast.RShift: ir_values.operators.rshift,
ast.BitOr: ir_values.operators.or_,
ast.BitXor: ir_values.operators.xor,
ast.BitAnd: ir_values.operators.and_
}
_ast_cmps = {
ast.Eq: ir_values.operators.eq,
ast.NotEq: ir_values.operators.ne,
ast.Lt: ir_values.operators.lt,
ast.LtE: ir_values.operators.le,
ast.Gt: ir_values.operators.gt,
ast.GtE: ir_values.operators.ge
}
_ast_unfuns = {
"bool": ir_values.operators.bool,
"int": ir_values.operators.int,
"int64": ir_values.operators.int64,
"round": ir_values.operators.round,
"round64": ir_values.operators.round64,
}
class ExpressionVisitor:
def __init__(self, builder, ns):
self.builder = builder
self.ns = ns
def visit(self, node):
if isinstance(node, ast.Name):
return self.ns.load(self.builder, node.id)
elif isinstance(node, ast.NameConstant):
v = node.value
if isinstance(v, bool):
r = ir_values.VBool()
else:
raise NotImplementedError
if self.builder is not None:
r.create_constant(v)
return r
elif isinstance(node, ast.Num):
n = node.n
if isinstance(n, int):
if abs(n) < 2**31:
r = ir_values.VInt()
else:
r = ir_values.VInt(64)
else:
raise NotImplementedError
if self.builder is not None:
r.create_constant(n)
return r
elif isinstance(node, ast.UnaryOp):
return _ast_unops[type(node.op)](self.visit(node.operand), self.builder)
elif isinstance(node, ast.BinOp):
return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder)
elif isinstance(node, ast.Compare):
comparisons = []
old_comparator = self.visit(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit(comparator_a)
comparison = _ast_cmps[type(op)](old_comparator, comparator)
comparisons.append(comparison)
old_comparator = comparator
r = comparisons[0]
for comparison in comparisons[1:]:
r = ir_values.operators.and_(r, comparison)
return r
elif isinstance(node, ast.Call):
return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder)
else:
raise NotImplementedError

View File

@ -0,0 +1,65 @@
import ast
from operator import itemgetter
from copy import deepcopy
from artiq.compiler.ir_ast_body import ExpressionVisitor
class _Namespace:
def __init__(self, name_to_value):
self.name_to_value = name_to_value
def load(self, builder, name):
return self.name_to_value[name]
class _TypeScanner(ast.NodeVisitor):
def __init__(self, namespace):
self.exprv = ExpressionVisitor(None, namespace)
def visit_Assign(self, node):
val = self.exprv.visit(node.value)
n2v = self.exprv.ns.name_to_value
for target in node.targets:
if isinstance(target, ast.Name):
if target.id in n2v:
n2v[target.id].merge(val)
else:
n2v[target.id] = val
else:
raise NotImplementedError
def visit_AugAssign(self, node):
val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value))
n2v = self.exprv.ns.name_to_value
target = node.target
if isinstance(target, ast.Name):
if target.id in n2v:
n2v[target.id].merge(val)
else:
n2v[target.id] = val
else:
raise NotImplementedError
def infer_types(node):
name_to_value = dict()
while True:
prev_name_to_value = deepcopy(name_to_value)
ns = _Namespace(name_to_value)
ts = _TypeScanner(ns)
ts.visit(node)
if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()):
# no more promotions - completed
return name_to_value
if __name__ == "__main__":
testcode = """
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True
"""
n2v = infer_types(ast.parse(testcode))
for k, v in sorted(n2v.items(), key=itemgetter(0)):
print("{:10}--> {}".format(k, str(v)))

164
artiq/compiler/ir_values.py Normal file
View File

@ -0,0 +1,164 @@
from types import SimpleNamespace
from llvm import core as lc
# Integer type
class VInt:
def __init__(self, nbits=32, llvm_value=None):
self.nbits = nbits
self.llvm_value = llvm_value
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError
def create_constant(self, n):
self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n)
def create_alloca(self, builder, name):
self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name)
def o_bool(self, builder):
if builder is None:
return VBool()
else:
zero = lc.Constant.int(lc.Type.int(self.nbits), 0)
return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero))
def o_int(self, builder):
if builder is None:
return VInt()
else:
if self.nbits == 32:
return self
else:
raise NotImplementedError
o_round = o_int
def o_int64(self, builder):
if builder is None:
return VInt(64)
else:
if self.nbits == 64:
return self
else:
raise NotImplementedError
o_round64 = o_int64
def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder):
if isinstance(other, VInt):
nbits = max(self.nbits, other.nbits)
if builder is None:
return VInt(nbits)
else:
bf = getattr(builder, builder_name)
return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value))
else:
return NotImplemented
return binop_method
for _method_name, _builder_name in (
("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
if builder is None:
return VBool()
else:
return VBool(llvm_value=builder.icmp(icmp_val, self, other))
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (
("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
# Boolean type
class VBool(VInt):
def __init__(self, llvm_value=None):
VInt.__init__(self, 1, llvm_value)
def __repr__(self):
return "<VBool>"
def merge(self, other):
if not isinstance(other, VBool):
raise TypeError
def create_constant(self, b):
VInt.create_constant(self, int(b))
# Operators
def _make_unary_operator(op_name):
def op(x, builder):
try:
opf = getattr(x, "o_"+op_name)
except AttributeError:
raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__))
return opf(builder)
return op
def _make_binary_operator(op_name):
def op(l, r, builder):
try:
opf = getattr(l, "o_"+op_name)
except AttributeError:
result = NotImplemented
else:
result = opf(r, builder)
if result is NotImplemented:
try:
ropf = getattr(l, "or_"+op_name)
except AttributeError:
result = NotImplemented
else:
result = ropf(r, builder)
if result is NotImplemented:
raise TypeError("Unsupported operand types for {}: {} and {}".format(
op_name, type(l).__name__, type(r).__name__))
return result
return op
def _make_operators():
d = dict()
for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"):
d[op_name] = _make_unary_operator(op_name)
d["not_"] = _make_binary_operator("not")
for op_name in ("add", "sub", "mul",
"truediv", "floordiv", "mod",
"pow", "lshift", "rshift", "xor",
"eq", "ne", "lt", "le", "gt", "ge"):
d[op_name] = _make_binary_operator(op_name)
d["and_"] = _make_binary_operator("and")
d["or_"] = _make_binary_operator("or")
return SimpleNamespace(**d)
operators = _make_operators()