forked from M-Labs/artiq
compiler: new architecture for type inference and LLVM code emission
This commit is contained in:
parent
9189ad5fab
commit
3e4cbba018
|
@ -1,179 +0,0 @@
|
|||
from collections import namedtuple
|
||||
import ast
|
||||
|
||||
from artiq.language import units
|
||||
|
||||
TBool = namedtuple("TBool", "")
|
||||
TFloat = namedtuple("TFloat", "")
|
||||
TInt = namedtuple("TInt", "nbits")
|
||||
TFraction = namedtuple("TFraction", "")
|
||||
|
||||
class TypeAnnotation:
|
||||
def __init__(self, t, unit=None):
|
||||
self.t = t
|
||||
self.unit = unit
|
||||
|
||||
def __repr__(self):
|
||||
r = "TypeAnnotation("+str(self.t)
|
||||
if self.unit is not None:
|
||||
r += " <unit:"+str(self.unit.name)+">"
|
||||
r += ")"
|
||||
return r
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.t == other.t and self.unit == other.unit
|
||||
|
||||
def promote(self, ta):
|
||||
if ta.unit != self.unit:
|
||||
raise units.DimensionError
|
||||
if isinstance(self.t, TBool):
|
||||
if not isinstance(ta.t, TBool):
|
||||
raise TypeError
|
||||
elif isinstance(self.t, TFloat):
|
||||
if not isinstance(ta.t, TFloat):
|
||||
raise TypeError
|
||||
elif isinstance(self.t, TInt):
|
||||
if isinstance(ta.t, TInt):
|
||||
self.t = TInt(max(self.t.nbits, ta.t.nbits))
|
||||
else:
|
||||
raise TypeError
|
||||
elif isinstance(self.t, TFraction):
|
||||
if not isinstance(ta.t, TFraction):
|
||||
raise TypeError
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def _get_addsub_type(l, r):
|
||||
if l.unit != r.unit:
|
||||
raise units.DimensionError
|
||||
if isinstance(l.t, TFloat):
|
||||
if isinstance(r.t, (TFloat, TInt, TFraction)):
|
||||
return l
|
||||
else:
|
||||
raise TypeError
|
||||
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit)
|
||||
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
|
||||
return r
|
||||
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
|
||||
return r
|
||||
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||
return l
|
||||
raise TypeError
|
||||
|
||||
def _get_mul_type(l, r):
|
||||
unit = l.unit
|
||||
if r.unit is not None:
|
||||
if unit is None:
|
||||
unit = r.unit
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if isinstance(l.t, TFloat):
|
||||
if isinstance(r.t, (TFloat, TInt, TFraction)):
|
||||
return TypeAnnotation(TFloat(), unit)
|
||||
else:
|
||||
raise TypeError
|
||||
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
|
||||
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
|
||||
return TypeAnnotation(r.t, unit)
|
||||
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
|
||||
return TypeAnnotation(TFloat(), unit)
|
||||
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||
return TypeAnnotation(TFraction(), unit)
|
||||
raise TypeError
|
||||
|
||||
def _get_div_unit(l, r):
|
||||
if l.unit is not None and r.unit is None:
|
||||
return l.unit
|
||||
elif l.unit == r.unit:
|
||||
return None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_truediv_type(l, r):
|
||||
unit = _get_div_unit(l, r)
|
||||
if isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
|
||||
return TypeAnnotation(TFraction(), unit)
|
||||
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||
return TypeAnnotation(TFraction(), unit)
|
||||
else:
|
||||
return TypeAnnotation(TFloat(), unit)
|
||||
|
||||
def _get_floordiv_type(l, r):
|
||||
unit = _get_div_unit(l, r)
|
||||
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
|
||||
elif isinstance(l.t, (TInt, TFloat)) and isinstance(r.t, TFloat):
|
||||
return TypeAnnotation(TFloat(), unit)
|
||||
elif isinstance(l.t, TFloat) and isinstance(r.t, (TInt, TFloat)):
|
||||
return TypeAnnotation(TFloat(), unit)
|
||||
elif (isinstance(l.t, TFloat) and isinstance(r.t, TFraction)) or (isinstance(l.t, TFraction) and isinstance(r.t, TFloat)):
|
||||
return TypeAnnotation(TInt(64), unit)
|
||||
elif isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
|
||||
return TypeAnnotation(TFraction(), unit)
|
||||
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||
return TypeAnnotation(TFraction(), unit)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_call_type(sym_to_type, node):
|
||||
fn = node.func.id
|
||||
if fn == "bool":
|
||||
return TypeAnnotation(TBool())
|
||||
elif fn == "float":
|
||||
return TypeAnnotation(TFloat())
|
||||
elif fn == "int" or fn == "round":
|
||||
return TypeAnnotation(TInt(32))
|
||||
elif fn == "int64" or fn == "round64":
|
||||
return TypeAnnotation(TInt(64))
|
||||
elif fn == "Fraction":
|
||||
return TypeAnnotation(TFraction())
|
||||
elif fn == "Quantity":
|
||||
ta = _get_expr_type(sym_to_type, node.args[0])
|
||||
ta.unit = getattr(units, node.args[1].id)
|
||||
return ta
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_expr_type(sym_to_type, node):
|
||||
if isinstance(node, ast.NameConstant):
|
||||
if isinstance(node.value, bool):
|
||||
return TypeAnnotation(TBool())
|
||||
else:
|
||||
raise TypeError
|
||||
elif isinstance(node, ast.Num):
|
||||
if isinstance(node.n, int):
|
||||
nbits = 32 if abs(node.n) < 2**31 else 64
|
||||
return TypeAnnotation(TInt(nbits))
|
||||
elif isinstance(node.n, float):
|
||||
return TypeAnnotation(TFloat())
|
||||
else:
|
||||
raise TypeError
|
||||
elif isinstance(node, ast.Name):
|
||||
return sym_to_type[node.id]
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
return _get_expr_type(sym_to_type, node.operand)
|
||||
elif isinstance(node, ast.Compare):
|
||||
return TypeAnnotation(TBool())
|
||||
elif isinstance(node, ast.BinOp):
|
||||
l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right)
|
||||
if isinstance(node.op, (ast.Add, ast.Sub)):
|
||||
return _get_addsub_type(l, r)
|
||||
elif isinstance(node.op, ast.Mul):
|
||||
return _get_mul_type(l, r)
|
||||
elif isinstance(node.op, ast.Div):
|
||||
return _get_truediv_type(l, r)
|
||||
elif isinstance(node.op, ast.FloorDiv):
|
||||
return _get_floordiv_type(l, r)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif isinstance(node, ast.Call):
|
||||
return _get_call_type(sym_to_type, node)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
testexpr = ast.parse(sys.argv[1], mode="eval")
|
||||
print(_get_expr_type(dict(), testexpr.body))
|
|
@ -0,0 +1,92 @@
|
|||
import ast
|
||||
|
||||
from artiq.compiler import ir_values
|
||||
|
||||
_ast_unops = {
|
||||
ast.Invert: ir_values.operators.inv,
|
||||
ast.Not: ir_values.operators.not_,
|
||||
ast.UAdd: ir_values.operators.pos,
|
||||
ast.USub: ir_values.operators.neg
|
||||
}
|
||||
|
||||
_ast_binops = {
|
||||
ast.Add: ir_values.operators.add,
|
||||
ast.Sub: ir_values.operators.sub,
|
||||
ast.Mult: ir_values.operators.mul,
|
||||
ast.Div: ir_values.operators.truediv,
|
||||
ast.FloorDiv: ir_values.operators.floordiv,
|
||||
ast.Mod: ir_values.operators.mod,
|
||||
ast.Pow: ir_values.operators.pow,
|
||||
ast.LShift: ir_values.operators.lshift,
|
||||
ast.RShift: ir_values.operators.rshift,
|
||||
ast.BitOr: ir_values.operators.or_,
|
||||
ast.BitXor: ir_values.operators.xor,
|
||||
ast.BitAnd: ir_values.operators.and_
|
||||
}
|
||||
|
||||
_ast_cmps = {
|
||||
ast.Eq: ir_values.operators.eq,
|
||||
ast.NotEq: ir_values.operators.ne,
|
||||
ast.Lt: ir_values.operators.lt,
|
||||
ast.LtE: ir_values.operators.le,
|
||||
ast.Gt: ir_values.operators.gt,
|
||||
ast.GtE: ir_values.operators.ge
|
||||
}
|
||||
|
||||
_ast_unfuns = {
|
||||
"bool": ir_values.operators.bool,
|
||||
"int": ir_values.operators.int,
|
||||
"int64": ir_values.operators.int64,
|
||||
"round": ir_values.operators.round,
|
||||
"round64": ir_values.operators.round64,
|
||||
}
|
||||
|
||||
class ExpressionVisitor:
|
||||
def __init__(self, builder, ns):
|
||||
self.builder = builder
|
||||
self.ns = ns
|
||||
|
||||
def visit(self, node):
|
||||
if isinstance(node, ast.Name):
|
||||
return self.ns.load(self.builder, node.id)
|
||||
elif isinstance(node, ast.NameConstant):
|
||||
v = node.value
|
||||
if isinstance(v, bool):
|
||||
r = ir_values.VBool()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(v)
|
||||
return r
|
||||
elif isinstance(node, ast.Num):
|
||||
n = node.n
|
||||
if isinstance(n, int):
|
||||
if abs(n) < 2**31:
|
||||
r = ir_values.VInt()
|
||||
else:
|
||||
r = ir_values.VInt(64)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(n)
|
||||
return r
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
return _ast_unops[type(node.op)](self.visit(node.operand), self.builder)
|
||||
elif isinstance(node, ast.BinOp):
|
||||
return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder)
|
||||
elif isinstance(node, ast.Compare):
|
||||
comparisons = []
|
||||
old_comparator = self.visit(node.left)
|
||||
for op, comparator_a in zip(node.ops, node.comparators):
|
||||
comparator = self.visit(comparator_a)
|
||||
comparison = _ast_cmps[type(op)](old_comparator, comparator)
|
||||
comparisons.append(comparison)
|
||||
old_comparator = comparator
|
||||
r = comparisons[0]
|
||||
for comparison in comparisons[1:]:
|
||||
r = ir_values.operators.and_(r, comparison)
|
||||
return r
|
||||
elif isinstance(node, ast.Call):
|
||||
return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder)
|
||||
else:
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,65 @@
|
|||
import ast
|
||||
from operator import itemgetter
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.compiler.ir_ast_body import ExpressionVisitor
|
||||
|
||||
class _Namespace:
|
||||
def __init__(self, name_to_value):
|
||||
self.name_to_value = name_to_value
|
||||
|
||||
def load(self, builder, name):
|
||||
return self.name_to_value[name]
|
||||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
def __init__(self, namespace):
|
||||
self.exprv = ExpressionVisitor(None, namespace)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit(node.value)
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in n2v:
|
||||
n2v[target.id].merge(val)
|
||||
else:
|
||||
n2v[target.id] = val
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
target = node.target
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in n2v:
|
||||
n2v[target.id].merge(val)
|
||||
else:
|
||||
n2v[target.id] = val
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_types(node):
|
||||
name_to_value = dict()
|
||||
while True:
|
||||
prev_name_to_value = deepcopy(name_to_value)
|
||||
ns = _Namespace(name_to_value)
|
||||
ts = _TypeScanner(ns)
|
||||
ts.visit(node)
|
||||
if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()):
|
||||
# no more promotions - completed
|
||||
return name_to_value
|
||||
|
||||
if __name__ == "__main__":
|
||||
testcode = """
|
||||
a = 2 # promoted later to int64
|
||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||
d = 4 # stays int32
|
||||
x = int64(7)
|
||||
a += x # promotes a to int64
|
||||
foo = True
|
||||
"""
|
||||
n2v = infer_types(ast.parse(testcode))
|
||||
for k, v in sorted(n2v.items(), key=itemgetter(0)):
|
||||
print("{:10}--> {}".format(k, str(v)))
|
|
@ -0,0 +1,164 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
# Integer type
|
||||
|
||||
class VInt:
|
||||
def __init__(self, nbits=32, llvm_value=None):
|
||||
self.nbits = nbits
|
||||
self.llvm_value = llvm_value
|
||||
|
||||
def __repr__(self):
|
||||
return "<VInt:{}>".format(self.nbits)
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VInt) and other.nbits == self.nbits
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VInt) and not isinstance(other, VBool):
|
||||
if other.nbits > self.nbits:
|
||||
self.nbits = other.nbits
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def create_constant(self, n):
|
||||
self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n)
|
||||
|
||||
def create_alloca(self, builder, name):
|
||||
self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name)
|
||||
|
||||
def o_bool(self, builder):
|
||||
if builder is None:
|
||||
return VBool()
|
||||
else:
|
||||
zero = lc.Constant.int(lc.Type.int(self.nbits), 0)
|
||||
return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero))
|
||||
|
||||
def o_int(self, builder):
|
||||
if builder is None:
|
||||
return VInt()
|
||||
else:
|
||||
if self.nbits == 32:
|
||||
return self
|
||||
else:
|
||||
raise NotImplementedError
|
||||
o_round = o_int
|
||||
|
||||
def o_int64(self, builder):
|
||||
if builder is None:
|
||||
return VInt(64)
|
||||
else:
|
||||
if self.nbits == 64:
|
||||
return self
|
||||
else:
|
||||
raise NotImplementedError
|
||||
o_round64 = o_int64
|
||||
|
||||
def _make_vint_binop_method(builder_name):
|
||||
def binop_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
nbits = max(self.nbits, other.nbits)
|
||||
if builder is None:
|
||||
return VInt(nbits)
|
||||
else:
|
||||
bf = getattr(builder, builder_name)
|
||||
return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value))
|
||||
else:
|
||||
return NotImplemented
|
||||
return binop_method
|
||||
|
||||
for _method_name, _builder_name in (
|
||||
("o_add", "add"),
|
||||
("o_sub", "sub"),
|
||||
("o_mul", "mul"),
|
||||
("o_floordiv", "sdiv"),
|
||||
("o_mod", "srem"),
|
||||
("o_and", "and_"),
|
||||
("o_xor", "xor"),
|
||||
("o_or", "or_")):
|
||||
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
|
||||
|
||||
def _make_vint_cmp_method(icmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
if builder is None:
|
||||
return VBool()
|
||||
else:
|
||||
return VBool(llvm_value=builder.icmp(icmp_val, self, other))
|
||||
else:
|
||||
return NotImplemented
|
||||
return cmp_method
|
||||
|
||||
for _method_name, _icmp_val in (
|
||||
("o_eq", lc.ICMP_EQ),
|
||||
("o_ne", lc.ICMP_NE),
|
||||
("o_lt", lc.ICMP_SLT),
|
||||
("o_le", lc.ICMP_SLE),
|
||||
("o_gt", lc.ICMP_SGT),
|
||||
("o_ge", lc.ICMP_SGE)):
|
||||
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
||||
|
||||
# Boolean type
|
||||
|
||||
class VBool(VInt):
|
||||
def __init__(self, llvm_value=None):
|
||||
VInt.__init__(self, 1, llvm_value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VBool>"
|
||||
|
||||
def merge(self, other):
|
||||
if not isinstance(other, VBool):
|
||||
raise TypeError
|
||||
|
||||
def create_constant(self, b):
|
||||
VInt.create_constant(self, int(b))
|
||||
|
||||
# Operators
|
||||
|
||||
def _make_unary_operator(op_name):
|
||||
def op(x, builder):
|
||||
try:
|
||||
opf = getattr(x, "o_"+op_name)
|
||||
except AttributeError:
|
||||
raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__))
|
||||
return opf(builder)
|
||||
return op
|
||||
|
||||
def _make_binary_operator(op_name):
|
||||
def op(l, r, builder):
|
||||
try:
|
||||
opf = getattr(l, "o_"+op_name)
|
||||
except AttributeError:
|
||||
result = NotImplemented
|
||||
else:
|
||||
result = opf(r, builder)
|
||||
if result is NotImplemented:
|
||||
try:
|
||||
ropf = getattr(l, "or_"+op_name)
|
||||
except AttributeError:
|
||||
result = NotImplemented
|
||||
else:
|
||||
result = ropf(r, builder)
|
||||
if result is NotImplemented:
|
||||
raise TypeError("Unsupported operand types for {}: {} and {}".format(
|
||||
op_name, type(l).__name__, type(r).__name__))
|
||||
return result
|
||||
return op
|
||||
|
||||
def _make_operators():
|
||||
d = dict()
|
||||
for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"):
|
||||
d[op_name] = _make_unary_operator(op_name)
|
||||
d["not_"] = _make_binary_operator("not")
|
||||
for op_name in ("add", "sub", "mul",
|
||||
"truediv", "floordiv", "mod",
|
||||
"pow", "lshift", "rshift", "xor",
|
||||
"eq", "ne", "lt", "le", "gt", "ge"):
|
||||
d[op_name] = _make_binary_operator(op_name)
|
||||
d["and_"] = _make_binary_operator("and")
|
||||
d["or_"] = _make_binary_operator("or")
|
||||
return SimpleNamespace(**d)
|
||||
|
||||
operators = _make_operators()
|
Loading…
Reference in New Issue