diff --git a/artiq/compiler/infer_types.py b/artiq/compiler/infer_types.py index 7acebb65a..965ebf1bb 100644 --- a/artiq/compiler/infer_types.py +++ b/artiq/compiler/infer_types.py @@ -1,16 +1,11 @@ from collections import namedtuple -from fractions import gcd import ast from artiq.language import units -def _lcm(a, b): - return a*b//gcd(a, b) - TBool = namedtuple("TBool", "") TFloat = namedtuple("TFloat", "") TInt = namedtuple("TInt", "nbits") -TFractionCD = namedtuple("TFractionCD", "denominator") TFraction = namedtuple("TFraction", "") class TypeAnnotation: @@ -25,6 +20,9 @@ class TypeAnnotation: r += ")" return r + def __eq__(self, other): + return self.t == other.t and self.unit == other.unit + def promote(self, ta): if ta.unit != self.unit: raise units.DimensionError @@ -37,25 +35,107 @@ class TypeAnnotation: elif isinstance(self.t, TInt): if isinstance(ta.t, TInt): self.t = TInt(max(self.t.nbits, ta.t.nbits)) - elif isinstance(ta.t, (TFractionCD, TFraction)): - self.t = ta.t - else: - raise TypeError - elif isinstance(self.t, TFractionCD): - if isinstance(ta.t, TInt): - pass - elif isinstance(ta.t, TFractionCD): - self.t = TFractionCD(_lcm(self.t.denominator, ta.t.denominator)) - elif isinstance(ta.t, TFraction): - self.t = TFraction() else: raise TypeError elif isinstance(self.t, TFraction): - if not isinstance(ta.t, (TInt, TFractionCD, TFraction)): + if not isinstance(ta.t, TFraction): raise TypeError else: raise TypeError +def _get_addsub_type(l, r): + if l.unit != r.unit: + raise units.DimensionError + if isinstance(l.t, TFloat): + if isinstance(r.t, (TFloat, TInt, TFraction)): + return l + else: + raise TypeError + if isinstance(l.t, TInt) and isinstance(r.t, TInt): + return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit) + if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)): + return r + if isinstance(l.t, TFraction) and isinstance(r.t, TFloat): + return r + if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): + return l + raise TypeError + +def _get_mul_type(l, r): + unit = l.unit + if r.unit is not None: + if unit is None: + unit = r.unit + else: + raise NotImplementedError + if isinstance(l.t, TFloat): + if isinstance(r.t, (TFloat, TInt, TFraction)): + return TypeAnnotation(TFloat(), unit) + else: + raise TypeError + if isinstance(l.t, TInt) and isinstance(r.t, TInt): + return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit) + if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)): + return TypeAnnotation(r.t, unit) + if isinstance(l.t, TFraction) and isinstance(r.t, TFloat): + return TypeAnnotation(TFloat(), unit) + if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): + return TypeAnnotation(TFraction(), unit) + raise TypeError + +def _get_div_unit(l, r): + if l.unit is not None and r.unit is None: + return l.unit + elif l.unit == r.unit: + return None + else: + raise NotImplementedError + +def _get_truediv_type(l, r): + unit = _get_div_unit(l, r) + if isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction): + return TypeAnnotation(TFraction(), unit) + elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): + return TypeAnnotation(TFraction(), unit) + else: + return TypeAnnotation(TFloat(), unit) + +def _get_floordiv_type(l, r): + unit = _get_div_unit(l, r) + if isinstance(l.t, TInt) and isinstance(r.t, TInt): + return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit) + elif isinstance(l.t, (TInt, TFloat)) and isinstance(r.t, TFloat): + return TypeAnnotation(TFloat(), unit) + elif isinstance(l.t, TFloat) and isinstance(r.t, (TInt, TFloat)): + return TypeAnnotation(TFloat(), unit) + elif (isinstance(l.t, TFloat) and isinstance(r.t, TFraction)) or (isinstance(l.t, TFraction) and isinstance(r.t, TFloat)): + return TypeAnnotation(TInt(64), unit) + elif isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction): + return TypeAnnotation(TFraction(), unit) + elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)): + return TypeAnnotation(TFraction(), unit) + else: + raise NotImplementedError + +def _get_call_type(sym_to_type, node): + fn = node.func.id + if fn == "bool": + return TypeAnnotation(TBool()) + elif fn == "float": + return TypeAnnotation(TFloat()) + elif fn == "int" or fn == "round": + return TypeAnnotation(TInt(32)) + elif fn == "int64" or fn == "round64": + return TypeAnnotation(TInt(64)) + elif fn == "Fraction": + return TypeAnnotation(TFraction()) + elif fn == "Quantity": + ta = _get_expr_type(sym_to_type, node.args[0]) + ta.unit = getattr(units, node.args[1].id) + return ta + else: + raise NotImplementedError + def _get_expr_type(sym_to_type, node): if isinstance(node, ast.NameConstant): if isinstance(node.value, bool): @@ -79,52 +159,17 @@ def _get_expr_type(sym_to_type, node): elif isinstance(node, ast.BinOp): l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right) if isinstance(node.op, (ast.Add, ast.Sub)): - if l.unit != r.unit: - raise units.DimensionError - if isinstance(l.t, TFloat): - if isinstance(r.t, (TFloat, TInt, TFraction, TFractionCD)): - return l - else: - raise TypeError - if isinstance(l.t, TInt) and isinstance(r.t, TInt): - return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit) - if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction, TFractionCD)): - return r - if isinstance(l.t, (TFractionCD, TFraction)) and isinstance(r.t, TFloat): - return r - if isinstance(l.t, TFractionCD) and isinstance(r.t, TInt): - return l - if isinstance(l.t, TFractionCD) and isinstance(r.t, TFractionCD): - return TypeAnnotation(TFractionCD(_lcm(l.t.denominator, r.t.denominator)), l.unit) - if isinstance(l.t, TFractionCD) and isinstance(r.t, TFraction): - return TypeAnnotation(TFraction()) - if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFractionCD, TFraction)): - return l - raise TypeError + return _get_addsub_type(l, r) + elif isinstance(node.op, ast.Mul): + return _get_mul_type(l, r) + elif isinstance(node.op, ast.Div): + return _get_truediv_type(l, r) + elif isinstance(node.op, ast.FloorDiv): + return _get_floordiv_type(l, r) else: raise NotImplementedError elif isinstance(node, ast.Call): - if node.func.id == "bool": - return TypeAnnotation(TBool()) - elif node.func.id == "float": - return TypeAnnotation(TFloat()) - elif node.func.id == "int": - return TypeAnnotation(TInt(32)) - elif node.func.id == "int64": - return TypeAnnotation(TInt(64)) - elif node.func.id == "Fraction": - if len(node.args) == 2 and isinstance(node.args[1], ast.Num): - if not isinstance(node.args[1].n, int): - raise TypeError - return TypeAnnotation(TFractionCD(node.args[1].n)) - else: - return TypeAnnotation(TFraction()) - elif node.func.id == "Quantity": - ta = _get_expr_type(sym_to_type, node.args[0]) - ta.unit = getattr(units, node.args[1].id) - return ta - else: - raise NotImplementedError + return _get_call_type(sym_to_type, node) else: raise NotImplementedError