forked from M-Labs/artiq
infer_type: better rules
This commit is contained in:
parent
232092166e
commit
02798d1996
|
@ -1,16 +1,11 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from fractions import gcd
|
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
|
||||||
def _lcm(a, b):
|
|
||||||
return a*b//gcd(a, b)
|
|
||||||
|
|
||||||
TBool = namedtuple("TBool", "")
|
TBool = namedtuple("TBool", "")
|
||||||
TFloat = namedtuple("TFloat", "")
|
TFloat = namedtuple("TFloat", "")
|
||||||
TInt = namedtuple("TInt", "nbits")
|
TInt = namedtuple("TInt", "nbits")
|
||||||
TFractionCD = namedtuple("TFractionCD", "denominator")
|
|
||||||
TFraction = namedtuple("TFraction", "")
|
TFraction = namedtuple("TFraction", "")
|
||||||
|
|
||||||
class TypeAnnotation:
|
class TypeAnnotation:
|
||||||
|
@ -25,6 +20,9 @@ class TypeAnnotation:
|
||||||
r += ")"
|
r += ")"
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.t == other.t and self.unit == other.unit
|
||||||
|
|
||||||
def promote(self, ta):
|
def promote(self, ta):
|
||||||
if ta.unit != self.unit:
|
if ta.unit != self.unit:
|
||||||
raise units.DimensionError
|
raise units.DimensionError
|
||||||
|
@ -37,25 +35,107 @@ class TypeAnnotation:
|
||||||
elif isinstance(self.t, TInt):
|
elif isinstance(self.t, TInt):
|
||||||
if isinstance(ta.t, TInt):
|
if isinstance(ta.t, TInt):
|
||||||
self.t = TInt(max(self.t.nbits, ta.t.nbits))
|
self.t = TInt(max(self.t.nbits, ta.t.nbits))
|
||||||
elif isinstance(ta.t, (TFractionCD, TFraction)):
|
|
||||||
self.t = ta.t
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
elif isinstance(self.t, TFractionCD):
|
|
||||||
if isinstance(ta.t, TInt):
|
|
||||||
pass
|
|
||||||
elif isinstance(ta.t, TFractionCD):
|
|
||||||
self.t = TFractionCD(_lcm(self.t.denominator, ta.t.denominator))
|
|
||||||
elif isinstance(ta.t, TFraction):
|
|
||||||
self.t = TFraction()
|
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
elif isinstance(self.t, TFraction):
|
elif isinstance(self.t, TFraction):
|
||||||
if not isinstance(ta.t, (TInt, TFractionCD, TFraction)):
|
if not isinstance(ta.t, TFraction):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
def _get_addsub_type(l, r):
|
||||||
|
if l.unit != r.unit:
|
||||||
|
raise units.DimensionError
|
||||||
|
if isinstance(l.t, TFloat):
|
||||||
|
if isinstance(r.t, (TFloat, TInt, TFraction)):
|
||||||
|
return l
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||||
|
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit)
|
||||||
|
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
|
||||||
|
return r
|
||||||
|
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
|
||||||
|
return r
|
||||||
|
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||||
|
return l
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def _get_mul_type(l, r):
|
||||||
|
unit = l.unit
|
||||||
|
if r.unit is not None:
|
||||||
|
if unit is None:
|
||||||
|
unit = r.unit
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if isinstance(l.t, TFloat):
|
||||||
|
if isinstance(r.t, (TFloat, TInt, TFraction)):
|
||||||
|
return TypeAnnotation(TFloat(), unit)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||||
|
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
|
||||||
|
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction)):
|
||||||
|
return TypeAnnotation(r.t, unit)
|
||||||
|
if isinstance(l.t, TFraction) and isinstance(r.t, TFloat):
|
||||||
|
return TypeAnnotation(TFloat(), unit)
|
||||||
|
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||||
|
return TypeAnnotation(TFraction(), unit)
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def _get_div_unit(l, r):
|
||||||
|
if l.unit is not None and r.unit is None:
|
||||||
|
return l.unit
|
||||||
|
elif l.unit == r.unit:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_truediv_type(l, r):
|
||||||
|
unit = _get_div_unit(l, r)
|
||||||
|
if isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
|
||||||
|
return TypeAnnotation(TFraction(), unit)
|
||||||
|
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||||
|
return TypeAnnotation(TFraction(), unit)
|
||||||
|
else:
|
||||||
|
return TypeAnnotation(TFloat(), unit)
|
||||||
|
|
||||||
|
def _get_floordiv_type(l, r):
|
||||||
|
unit = _get_div_unit(l, r)
|
||||||
|
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
||||||
|
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), unit)
|
||||||
|
elif isinstance(l.t, (TInt, TFloat)) and isinstance(r.t, TFloat):
|
||||||
|
return TypeAnnotation(TFloat(), unit)
|
||||||
|
elif isinstance(l.t, TFloat) and isinstance(r.t, (TInt, TFloat)):
|
||||||
|
return TypeAnnotation(TFloat(), unit)
|
||||||
|
elif (isinstance(l.t, TFloat) and isinstance(r.t, TFraction)) or (isinstance(l.t, TFraction) and isinstance(r.t, TFloat)):
|
||||||
|
return TypeAnnotation(TInt(64), unit)
|
||||||
|
elif isinstance(l.t, (TInt, TFraction)) and isinstance(r.t, TFraction):
|
||||||
|
return TypeAnnotation(TFraction(), unit)
|
||||||
|
elif isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFraction)):
|
||||||
|
return TypeAnnotation(TFraction(), unit)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_call_type(sym_to_type, node):
|
||||||
|
fn = node.func.id
|
||||||
|
if fn == "bool":
|
||||||
|
return TypeAnnotation(TBool())
|
||||||
|
elif fn == "float":
|
||||||
|
return TypeAnnotation(TFloat())
|
||||||
|
elif fn == "int" or fn == "round":
|
||||||
|
return TypeAnnotation(TInt(32))
|
||||||
|
elif fn == "int64" or fn == "round64":
|
||||||
|
return TypeAnnotation(TInt(64))
|
||||||
|
elif fn == "Fraction":
|
||||||
|
return TypeAnnotation(TFraction())
|
||||||
|
elif fn == "Quantity":
|
||||||
|
ta = _get_expr_type(sym_to_type, node.args[0])
|
||||||
|
ta.unit = getattr(units, node.args[1].id)
|
||||||
|
return ta
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_expr_type(sym_to_type, node):
|
def _get_expr_type(sym_to_type, node):
|
||||||
if isinstance(node, ast.NameConstant):
|
if isinstance(node, ast.NameConstant):
|
||||||
if isinstance(node.value, bool):
|
if isinstance(node.value, bool):
|
||||||
|
@ -79,52 +159,17 @@ def _get_expr_type(sym_to_type, node):
|
||||||
elif isinstance(node, ast.BinOp):
|
elif isinstance(node, ast.BinOp):
|
||||||
l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right)
|
l, r = _get_expr_type(sym_to_type, node.left), _get_expr_type(sym_to_type, node.right)
|
||||||
if isinstance(node.op, (ast.Add, ast.Sub)):
|
if isinstance(node.op, (ast.Add, ast.Sub)):
|
||||||
if l.unit != r.unit:
|
return _get_addsub_type(l, r)
|
||||||
raise units.DimensionError
|
elif isinstance(node.op, ast.Mul):
|
||||||
if isinstance(l.t, TFloat):
|
return _get_mul_type(l, r)
|
||||||
if isinstance(r.t, (TFloat, TInt, TFraction, TFractionCD)):
|
elif isinstance(node.op, ast.Div):
|
||||||
return l
|
return _get_truediv_type(l, r)
|
||||||
else:
|
elif isinstance(node.op, ast.FloorDiv):
|
||||||
raise TypeError
|
return _get_floordiv_type(l, r)
|
||||||
if isinstance(l.t, TInt) and isinstance(r.t, TInt):
|
|
||||||
return TypeAnnotation(TInt(max(l.t.nbits, r.t.nbits)), l.unit)
|
|
||||||
if isinstance(l.t, TInt) and isinstance(r.t, (TFloat, TFraction, TFractionCD)):
|
|
||||||
return r
|
|
||||||
if isinstance(l.t, (TFractionCD, TFraction)) and isinstance(r.t, TFloat):
|
|
||||||
return r
|
|
||||||
if isinstance(l.t, TFractionCD) and isinstance(r.t, TInt):
|
|
||||||
return l
|
|
||||||
if isinstance(l.t, TFractionCD) and isinstance(r.t, TFractionCD):
|
|
||||||
return TypeAnnotation(TFractionCD(_lcm(l.t.denominator, r.t.denominator)), l.unit)
|
|
||||||
if isinstance(l.t, TFractionCD) and isinstance(r.t, TFraction):
|
|
||||||
return TypeAnnotation(TFraction())
|
|
||||||
if isinstance(l.t, TFraction) and isinstance(r.t, (TInt, TFractionCD, TFraction)):
|
|
||||||
return l
|
|
||||||
raise TypeError
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
elif isinstance(node, ast.Call):
|
elif isinstance(node, ast.Call):
|
||||||
if node.func.id == "bool":
|
return _get_call_type(sym_to_type, node)
|
||||||
return TypeAnnotation(TBool())
|
|
||||||
elif node.func.id == "float":
|
|
||||||
return TypeAnnotation(TFloat())
|
|
||||||
elif node.func.id == "int":
|
|
||||||
return TypeAnnotation(TInt(32))
|
|
||||||
elif node.func.id == "int64":
|
|
||||||
return TypeAnnotation(TInt(64))
|
|
||||||
elif node.func.id == "Fraction":
|
|
||||||
if len(node.args) == 2 and isinstance(node.args[1], ast.Num):
|
|
||||||
if not isinstance(node.args[1].n, int):
|
|
||||||
raise TypeError
|
|
||||||
return TypeAnnotation(TFractionCD(node.args[1].n))
|
|
||||||
else:
|
|
||||||
return TypeAnnotation(TFraction())
|
|
||||||
elif node.func.id == "Quantity":
|
|
||||||
ta = _get_expr_type(sym_to_type, node.args[0])
|
|
||||||
ta.unit = getattr(units, node.args[1].id)
|
|
||||||
return ta
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue