diff --git a/artiq/compiler/infer_types.py b/artiq/compiler/infer_types.py deleted file mode 100644 index 965ebf1bb..000000000 --- a/artiq/compiler/infer_types.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections import namedtuple -import ast - -from artiq.language import units - -TBool = namedtuple("TBool", "") -TFloat = namedtuple("TFloat", "") -TInt = namedtuple("TInt", "nbits") -TFraction = namedtuple("TFraction", "") - -class TypeAnnotation: - def __init__(self, t, unit=None): - self.t = t - self.unit = unit - - def __repr__(self): - r = "TypeAnnotation("+str(self.t) - if self.unit is not None: - r += " " - r += ")" - return r - - def __eq__(self, other): - return self.t == other.t and self.unit == other.unit - - def promote(self, ta): - if ta.unit != self.unit: - raise units.DimensionError - if isinstance(self.t, TBool): - if not isinstance(ta.t, TBool): - raise TypeError - elif isinstance(self.t, TFloat): - if not isinstance(ta.t, TFloat): - raise TypeError - elif isinstance(self.t, TInt): - if isinstance(ta.t, TInt): - self.t = TInt(max(self.t.nbits, ta.t.nbits)) - else: - raise TypeError - elif isinstance(self.t, TFraction): - if not isinstance(ta.t, TFraction): - raise TypeError - else: - raise TypeError - -def _get_addsub_type(l, r): - if l.unit != r.unit: - raise units.DimensionError - if isinstance(l.t, TFloat): - if isinstance(r.t, (TFloat, TInt, TFraction)): - return l - else: - raise TypeError - if isinstance(l.t, TInt) and isinstance(r.t, TInt): - return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit) - if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)): - return r - if isinstance(l.t, TFraction) and isinstance(r.t, TFloat): - return r - if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): - return l - raise TypeError - -def _get_mul_type(l, r): - unit = l.unit - if r.unit is not None: - if unit is None: - unit = r.unit - else: - raise NotImplementedError - if isinstance(l.t, TFloat): - if isinstance(r.t, (TFloat, TInt, TFraction)): - return TypeAnnotation(TFloat(), unit) - else: - raise TypeError - if isinstance(l.t, TInt) and isinstance(r.t, TInt): - return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit) - if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)): - return TypeAnnotation(r.t, unit) - if isinstance(l.t, TFraction) and isinstance(r.t, TFloat): - return TypeAnnotation(TFloat(), unit) - if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): - return TypeAnnotation(TFraction(), unit) - raise TypeError - -def _get_div_unit(l, r): - if l.unit is not None and r.unit is None: - return l.unit - elif l.unit == r.unit: - return None - else: - raise NotImplementedError - -def _get_truediv_type(l, r): - unit = _get_div_unit(l, r) - if isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction): - return TypeAnnotation(TFraction(), unit) - elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): - return TypeAnnotation(TFraction(), unit) - else: - return TypeAnnotation(TFloat(), unit) - -def _get_floordiv_type(l, r): - unit = _get_div_unit(l, r) - if isinstance(l.t, TInt) and isinstance(r.t, TInt): - return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit) - elif isinstance(l.t, (TInt, TFloat)) and isinstance(r.t, TFloat): - return TypeAnnotation(TFloat(), unit) - elif isinstance(l.t, TFloat) and isinstance(r.t, (TInt, TFloat)): - return TypeAnnotation(TFloat(), unit) - elif (isinstance(l.t, TFloat) and isinstance(r.t, TFraction)) or (isinstance(l.t, TFraction) and isinstance(r.t, TFloat)): - return TypeAnnotation(TInt(64), unit) - elif isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction): - return TypeAnnotation(TFraction(), unit) - elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): - return TypeAnnotation(TFraction(), unit) - else: - raise NotImplementedError - -def _get_call_type(sym_to_type, node): - fn = node.func.id - if fn == "bool": - return TypeAnnotation(TBool()) - elif fn == "float": - return TypeAnnotation(TFloat()) - elif fn == "int" or fn == "round": - return TypeAnnotation(TInt(32)) - elif fn == "int64" or fn == "round64": - return TypeAnnotation(TInt(64)) - elif fn == "Fraction": - return TypeAnnotation(TFraction()) - elif fn == "Quantity": - ta = _get_expr_type(sym_to_type, node.args[0]) - ta.unit = getattr(units, node.args[1].id) - return ta - else: - raise NotImplementedError - -def _get_expr_type(sym_to_type, node): - if isinstance(node, ast.NameConstant): - if isinstance(node.value, bool): - return TypeAnnotation(TBool()) - else: - raise TypeError - elif isinstance(node, ast.Num): - if isinstance(node.n, int): - nbits = 32 if abs(node.n) < 2**31 else 64 - return TypeAnnotation(TInt(nbits)) - elif isinstance(node.n, float): - return TypeAnnotation(TFloat()) - else: - raise TypeError - elif isinstance(node, ast.Name): - return sym_to_type[node.id] - elif isinstance(node, ast.UnaryOp): - return _get_expr_type(sym_to_type, node.operand) - elif isinstance(node, ast.Compare): - return TypeAnnotation(TBool()) - elif isinstance(node, ast.BinOp): - l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right) - if isinstance(node.op, (ast.Add, ast.Sub)): - return _get_addsub_type(l, r) - elif isinstance(node.op, ast.Mul): - return _get_mul_type(l, r) - elif isinstance(node.op, ast.Div): - return _get_truediv_type(l, r) - elif isinstance(node.op, ast.FloorDiv): - return _get_floordiv_type(l, r) - else: - raise NotImplementedError - elif isinstance(node, ast.Call): - return _get_call_type(sym_to_type, node) - else: - raise NotImplementedError - -if __name__ == "__main__": - import sys - testexpr = ast.parse(sys.argv[1], mode="eval") - print(_get_expr_type(dict(), testexpr.body)) diff --git a/artiq/compiler/ir_ast_body.py b/artiq/compiler/ir_ast_body.py new file mode 100644 index 000000000..58f18f1c8 --- /dev/null +++ b/artiq/compiler/ir_ast_body.py @@ -0,0 +1,92 @@ +import ast + +from artiq.compiler import ir_values + +_ast_unops = { + ast.Invert: ir_values.operators.inv, + ast.Not: ir_values.operators.not_, + ast.UAdd: ir_values.operators.pos, + ast.USub: ir_values.operators.neg +} + +_ast_binops = { + ast.Add: ir_values.operators.add, + ast.Sub: ir_values.operators.sub, + ast.Mult: ir_values.operators.mul, + ast.Div: ir_values.operators.truediv, + ast.FloorDiv: ir_values.operators.floordiv, + ast.Mod: ir_values.operators.mod, + ast.Pow: ir_values.operators.pow, + ast.LShift: ir_values.operators.lshift, + ast.RShift: ir_values.operators.rshift, + ast.BitOr: ir_values.operators.or_, + ast.BitXor: ir_values.operators.xor, + ast.BitAnd: ir_values.operators.and_ +} + +_ast_cmps = { + ast.Eq: ir_values.operators.eq, + ast.NotEq: ir_values.operators.ne, + ast.Lt: ir_values.operators.lt, + ast.LtE: ir_values.operators.le, + ast.Gt: ir_values.operators.gt, + ast.GtE: ir_values.operators.ge +} + +_ast_unfuns = { + "bool": ir_values.operators.bool, + "int": ir_values.operators.int, + "int64": ir_values.operators.int64, + "round": ir_values.operators.round, + "round64": ir_values.operators.round64, +} + +class ExpressionVisitor: + def __init__(self, builder, ns): + self.builder = builder + self.ns = ns + + def visit(self, node): + if isinstance(node, ast.Name): + return self.ns.load(self.builder, node.id) + elif isinstance(node, ast.NameConstant): + v = node.value + if isinstance(v, bool): + r = ir_values.VBool() + else: + raise NotImplementedError + if self.builder is not None: + r.create_constant(v) + return r + elif isinstance(node, ast.Num): + n = node.n + if isinstance(n, int): + if abs(n) < 2**31: + r = ir_values.VInt() + else: + r = ir_values.VInt(64) + else: + raise NotImplementedError + if self.builder is not None: + r.create_constant(n) + return r + elif isinstance(node, ast.UnaryOp): + return _ast_unops[type(node.op)](self.visit(node.operand), self.builder) + elif isinstance(node, ast.BinOp): + return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder) + elif isinstance(node, ast.Compare): + comparisons = [] + old_comparator = self.visit(node.left) + for op, comparator_a in zip(node.ops, node.comparators): + comparator = self.visit(comparator_a) + comparison = _ast_cmps[type(op)](old_comparator, comparator) + comparisons.append(comparison) + old_comparator = comparator + r = comparisons[0] + for comparison in comparisons[1:]: + r = ir_values.operators.and_(r, comparison) + return r + elif isinstance(node, ast.Call): + return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder) + else: + raise NotImplementedError diff --git a/artiq/compiler/ir_infer_types.py b/artiq/compiler/ir_infer_types.py new file mode 100644 index 000000000..afea172a5 --- /dev/null +++ b/artiq/compiler/ir_infer_types.py @@ -0,0 +1,65 @@ +import ast +from operator import itemgetter +from copy import deepcopy + +from artiq.compiler.ir_ast_body import ExpressionVisitor + +class _Namespace: + def __init__(self, name_to_value): + self.name_to_value = name_to_value + + def load(self, builder, name): + return self.name_to_value[name] + +class _TypeScanner(ast.NodeVisitor): + def __init__(self, namespace): + self.exprv = ExpressionVisitor(None, namespace) + + def visit_Assign(self, node): + val = self.exprv.visit(node.value) + n2v = self.exprv.ns.name_to_value + for target in node.targets: + if isinstance(target, ast.Name): + if target.id in n2v: + n2v[target.id].merge(val) + else: + n2v[target.id] = val + else: + raise NotImplementedError + + def visit_AugAssign(self, node): + val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value)) + n2v = self.exprv.ns.name_to_value + target = node.target + if isinstance(target, ast.Name): + if target.id in n2v: + n2v[target.id].merge(val) + else: + n2v[target.id] = val + else: + raise NotImplementedError + +def infer_types(node): + name_to_value = dict() + while True: + prev_name_to_value = deepcopy(name_to_value) + ns = _Namespace(name_to_value) + ts = _TypeScanner(ns) + ts.visit(node) + if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()): + # no more promotions - completed + return name_to_value + +if __name__ == "__main__": + testcode = """ +a = 2 # promoted later to int64 +b = a + 1 # initially int32, becomes int64 after a is promoted +c = b//2 # initially int32, becomes int64 after b is promoted +d = 4 # stays int32 +x = int64(7) +a += x # promotes a to int64 +foo = True +""" + n2v = infer_types(ast.parse(testcode)) + for k, v in sorted(n2v.items(), key=itemgetter(0)): + print("{:10}--> {}".format(k, str(v))) diff --git a/artiq/compiler/ir_values.py b/artiq/compiler/ir_values.py new file mode 100644 index 000000000..cdb331c21 --- /dev/null +++ b/artiq/compiler/ir_values.py @@ -0,0 +1,164 @@ +from types import SimpleNamespace + +from llvm import core as lc + +# Integer type + +class VInt: + def __init__(self, nbits=32, llvm_value=None): + self.nbits = nbits + self.llvm_value = llvm_value + + def __repr__(self): + return "".format(self.nbits) + + def same_type(self, other): + return isinstance(other, VInt) and other.nbits == self.nbits + + def merge(self, other): + if isinstance(other, VInt) and not isinstance(other, VBool): + if other.nbits > self.nbits: + self.nbits = other.nbits + else: + raise TypeError + + def create_constant(self, n): + self.llvm_value = lc.Constant.int(lc.Type.int(self.nbits), n) + + def create_alloca(self, builder, name): + self.llvm_value = builder.alloca(lc.Type.int(self.nbits), name=name) + + def o_bool(self, builder): + if builder is None: + return VBool() + else: + zero = lc.Constant.int(lc.Type.int(self.nbits), 0) + return VBool(llvm_value=builder.icmp(lc.ICMP_NE, self.llvm_value, zero)) + + def o_int(self, builder): + if builder is None: + return VInt() + else: + if self.nbits == 32: + return self + else: + raise NotImplementedError + o_round = o_int + + def o_int64(self, builder): + if builder is None: + return VInt(64) + else: + if self.nbits == 64: + return self + else: + raise NotImplementedError + o_round64 = o_int64 + +def _make_vint_binop_method(builder_name): + def binop_method(self, other, builder): + if isinstance(other, VInt): + nbits = max(self.nbits, other.nbits) + if builder is None: + return VInt(nbits) + else: + bf = getattr(builder, builder_name) + return VInt(nbits, llvm_value=bf(self.llvm_value, other.llvm_value)) + else: + return NotImplemented + return binop_method + +for _method_name, _builder_name in ( + ("o_add", "add"), + ("o_sub", "sub"), + ("o_mul", "mul"), + ("o_floordiv", "sdiv"), + ("o_mod", "srem"), + ("o_and", "and_"), + ("o_xor", "xor"), + ("o_or", "or_")): + setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) + +def _make_vint_cmp_method(icmp_val): + def cmp_method(self, other, builder): + if isinstance(other, VInt): + if builder is None: + return VBool() + else: + return VBool(llvm_value=builder.icmp(icmp_val, self, other)) + else: + return NotImplemented + return cmp_method + +for _method_name, _icmp_val in ( + ("o_eq", lc.ICMP_EQ), + ("o_ne", lc.ICMP_NE), + ("o_lt", lc.ICMP_SLT), + ("o_le", lc.ICMP_SLE), + ("o_gt", lc.ICMP_SGT), + ("o_ge", lc.ICMP_SGE)): + setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) + +# Boolean type + +class VBool(VInt): + def __init__(self, llvm_value=None): + VInt.__init__(self, 1, llvm_value) + + def __repr__(self): + return "" + + def merge(self, other): + if not isinstance(other, VBool): + raise TypeError + + def create_constant(self, b): + VInt.create_constant(self, int(b)) + +# Operators + +def _make_unary_operator(op_name): + def op(x, builder): + try: + opf = getattr(x, "o_"+op_name) + except AttributeError: + raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__)) + return opf(builder) + return op + +def _make_binary_operator(op_name): + def op(l, r, builder): + try: + opf = getattr(l, "o_"+op_name) + except AttributeError: + result = NotImplemented + else: + result = opf(r, builder) + if result is NotImplemented: + try: + ropf = getattr(l, "or_"+op_name) + except AttributeError: + result = NotImplemented + else: + result = ropf(r, builder) + if result is NotImplemented: + raise TypeError("Unsupported operand types for {}: {} and {}".format( + op_name, type(l).__name__, type(r).__name__)) + return result + return op + +def _make_operators(): + d = dict() + for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"): + d[op_name] = _make_unary_operator(op_name) + d["not_"] = _make_binary_operator("not") + for op_name in ("add", "sub", "mul", + "truediv", "floordiv", "mod", + "pow", "lshift", "rshift", "xor", + "eq", "ne", "lt", "le", "gt", "ge"): + d[op_name] = _make_binary_operator(op_name) + d["and_"] = _make_binary_operator("and") + d["or_"] = _make_binary_operator("or") + return SimpleNamespace(**d) + +operators = _make_operators()