Compare commits
11 Commits
Author | SHA1 | Date |
---|---|---|
pca006132 | 39b3faba6e | |
pca006132 | 1ad21f0d67 | |
pca006132 | dbf9c17d9f | |
pca006132 | 0010e5852a | |
pca006132 | 902e1a892c | |
pca006132 | b1020352ce | |
pca006132 | 66df55b3d7 | |
pca006132 | 0c8029d7e1 | |
pca006132 | e15e29d673 | |
pca006132 | 1c17ed003e | |
pca006132 | 59628cfa38 |
|
@ -0,0 +1,50 @@
|
||||||
|
# Type Inference
|
||||||
|
|
||||||
|
This is a prototype before Rust implementation of the algorithm.
|
||||||
|
|
||||||
|
## Implemented Features
|
||||||
|
- Primitive types: bool, float, integer. (limited magic methods are implemented,
|
||||||
|
no casting for now)
|
||||||
|
- Statements:
|
||||||
|
- For loop.
|
||||||
|
- While loop.
|
||||||
|
- If expression.
|
||||||
|
- Simple assignment. (without type annotation, allows pattern matcing for
|
||||||
|
tuples)
|
||||||
|
- Expressions:
|
||||||
|
- Boolean operations.
|
||||||
|
- Binary operations. (+, -, *, /, //)
|
||||||
|
- Compare.
|
||||||
|
- Function call. (posargs/kwargs/optional are supported)
|
||||||
|
- Lambda.
|
||||||
|
- Object Attribute.
|
||||||
|
- List/tuple subscript. (tuple requires constant indexing)
|
||||||
|
- Virtual.
|
||||||
|
- Constraints for type variables.
|
||||||
|
|
||||||
|
User can define functions/types by adding them to the prelude.
|
||||||
|
|
||||||
|
Note that variables can be used before definition for now, we would do another
|
||||||
|
pass after type checking to prevent this. The pass would also check for return.
|
||||||
|
|
||||||
|
We assume that class type do not contain any unbound type variables. (methods
|
||||||
|
can contain unbound type variables, if the variable is bound to the class)
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
- Parse class/function definition.
|
||||||
|
- Occur check to prevent infinite function types.
|
||||||
|
- Pretty print for types. (especially type variables)
|
||||||
|
- Better error message. (we did not keep any context for now)
|
||||||
|
|
||||||
|
## Implementation Notes
|
||||||
|
|
||||||
|
- Type variables would now retain a lot of information, including fields and its
|
||||||
|
range for later checking.
|
||||||
|
- Function type is seperated from call type. We maintain a list of calls to the
|
||||||
|
same function (potentially with different type signature) to allow different
|
||||||
|
instantiation of the same function.
|
||||||
|
- We store the list of calls and the list of virtual types, and check the
|
||||||
|
constraints at the end of the type inference phase. E.g. check subtyping
|
||||||
|
relationship, check whether the type variable can be instantiated to a certain
|
||||||
|
type.
|
||||||
|
|
|
@ -0,0 +1,237 @@
|
||||||
|
import ast
|
||||||
|
from itertools import chain
|
||||||
|
from nac3_types import *
|
||||||
|
from primitives import *
|
||||||
|
|
||||||
|
|
||||||
|
def get_magic_method(op):
|
||||||
|
if isinstance(op, ast.Add):
|
||||||
|
return '__add__'
|
||||||
|
if isinstance(op, ast.Sub):
|
||||||
|
return '__sub__'
|
||||||
|
if isinstance(op, ast.Mult):
|
||||||
|
return '__mul__'
|
||||||
|
if isinstance(op, ast.Div):
|
||||||
|
return '__truediv__'
|
||||||
|
if isinstance(op, ast.FloorDiv):
|
||||||
|
return '__floordiv__'
|
||||||
|
if isinstance(op, ast.Eq):
|
||||||
|
return '__eq__'
|
||||||
|
if isinstance(op, ast.NotEq):
|
||||||
|
return '__ne__'
|
||||||
|
if isinstance(op, ast.Lt):
|
||||||
|
return '__lt__'
|
||||||
|
if isinstance(op, ast.LtE):
|
||||||
|
return '__le__'
|
||||||
|
if isinstance(op, ast.Gt):
|
||||||
|
return '__gt__'
|
||||||
|
if isinstance(op, ast.GtE):
|
||||||
|
return '__ge__'
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
|
class Visitor(ast.NodeVisitor):
|
||||||
|
def __init__(self, src, assignments, type_parser):
|
||||||
|
super(Visitor, self).__init__()
|
||||||
|
self.source = src
|
||||||
|
self.assignments = assignments
|
||||||
|
self.calls = []
|
||||||
|
self.virtuals = []
|
||||||
|
self.type_parser = type_parser
|
||||||
|
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
self.visit(node.value)
|
||||||
|
for target in node.targets:
|
||||||
|
self.visit(target)
|
||||||
|
target.type.unify(node.value.type)
|
||||||
|
|
||||||
|
def visit_Name(self, node):
|
||||||
|
if node.id == '_':
|
||||||
|
node.type = TVar()
|
||||||
|
return
|
||||||
|
if node.id not in self.assignments:
|
||||||
|
self.assignments[node.id] = TVar()
|
||||||
|
ty = self.assignments[node.id]
|
||||||
|
if isinstance(ty, TFunc):
|
||||||
|
ty = ty.instantiate()
|
||||||
|
node.type = ty
|
||||||
|
|
||||||
|
def visit_Lambda(self, node):
|
||||||
|
old = self.assignments
|
||||||
|
self.assignments = old.copy()
|
||||||
|
self.visit(node.args)
|
||||||
|
self.visit(node.body)
|
||||||
|
self.assignments = old
|
||||||
|
node.type = TFunc(node.args.type, node.body.type, [])
|
||||||
|
|
||||||
|
def visit_arguments(self, node):
|
||||||
|
for arg in node.args:
|
||||||
|
self.assignments[arg.arg] = TVar()
|
||||||
|
arg.type = self.assignments[arg.arg]
|
||||||
|
node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]
|
||||||
|
|
||||||
|
def visit_Constant(self, node):
|
||||||
|
if isinstance(node.value, bool):
|
||||||
|
node.type = TBool
|
||||||
|
elif isinstance(node.value, int):
|
||||||
|
node.type = TInt
|
||||||
|
elif isinstance(node.value, float):
|
||||||
|
node.type = TFloat
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
||||||
|
if len(node.args) > 2 or len(node.args) < 1:
|
||||||
|
raise UnificationError('Incorrect argument number for virtual')
|
||||||
|
self.visit(node.args[0])
|
||||||
|
if len(node.args) == 2:
|
||||||
|
ty = self.type_parser(ast.get_source_segment(self.source,
|
||||||
|
node.args[1]))
|
||||||
|
else:
|
||||||
|
ty = TVar()
|
||||||
|
self.virtuals.append((node.args[0].type, ty))
|
||||||
|
node.type = TVirtual(ty)
|
||||||
|
return
|
||||||
|
self.visit(node.func)
|
||||||
|
for arg in node.args:
|
||||||
|
self.visit(arg)
|
||||||
|
for keyword in node.keywords:
|
||||||
|
self.visit(keyword.value)
|
||||||
|
node.type = TVar()
|
||||||
|
call = TCall([arg.type for arg in node.args],
|
||||||
|
{keyword.arg: keyword.value.type for keyword
|
||||||
|
in node.keywords}, node.type)
|
||||||
|
node.func.type.unify(call)
|
||||||
|
self.calls.append(call)
|
||||||
|
|
||||||
|
def visit_Attribute(self, node):
|
||||||
|
self.visit(node.value)
|
||||||
|
node.type = TVar()
|
||||||
|
v = TVar()
|
||||||
|
v.type = TVarType.RECORD
|
||||||
|
v.fields[node.attr] = node.type
|
||||||
|
node.value.type.unify(v)
|
||||||
|
|
||||||
|
def visit_Tuple(self, node):
|
||||||
|
for e in node.elts:
|
||||||
|
self.visit(e)
|
||||||
|
node.type = TTuple([e.type for e in node.elts])
|
||||||
|
|
||||||
|
def visit_List(self, node):
|
||||||
|
ty = TVar()
|
||||||
|
for e in node.elts:
|
||||||
|
self.visit(e)
|
||||||
|
ty.unify(e.type)
|
||||||
|
node.type = TList(ty)
|
||||||
|
|
||||||
|
def visit_Subscript(self, node):
|
||||||
|
self.visit(node.value)
|
||||||
|
if isinstance(node.slice, ast.Slice):
|
||||||
|
node.type = node.value.type
|
||||||
|
if isinstance(node.type, TVar):
|
||||||
|
node.type.type = node.type.type.unifier(TVarType.LIST)
|
||||||
|
elif not isinstance(node.type, TList):
|
||||||
|
raise UnificationError(f'{node.type} should be a list')
|
||||||
|
elif isinstance(node.slice, ast.ExtSlice):
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
# complicated because we need to handle lists and tuples
|
||||||
|
# differently...
|
||||||
|
self.visit(node.slice.value)
|
||||||
|
node.slice.value.type.unify(TInt)
|
||||||
|
ty = node.value.type
|
||||||
|
node.type = TVar()
|
||||||
|
if isinstance(ty, TVar):
|
||||||
|
index = 0
|
||||||
|
if isinstance(node.slice.value, ast.Constant):
|
||||||
|
seq_ty = TVarType.SEQUENCE
|
||||||
|
if isinstance(node.ctx, (ast.AugStore, ast.Store)):
|
||||||
|
seq_ty = TVarType.LIST
|
||||||
|
ty.type = ty.type.unifier(seq_ty)
|
||||||
|
index = node.slice.value.value
|
||||||
|
else:
|
||||||
|
ty.type = ty.type.unifier(TVarType.LIST)
|
||||||
|
if index in ty.fields:
|
||||||
|
ty.fields[index].unify(node.type)
|
||||||
|
else:
|
||||||
|
ty.fields[index] = node.type
|
||||||
|
elif isinstance(ty, TList):
|
||||||
|
ty.param.unify(node.type)
|
||||||
|
elif isinstance(ty, TTuple):
|
||||||
|
if isinstance(node.ctx, (ast.AugStore, ast.Store)):
|
||||||
|
raise UnificationError(f'Cannot assign to tuple')
|
||||||
|
if isinstance(node.slice.value, ast.Constant):
|
||||||
|
index = node.slice.value.value
|
||||||
|
if index >= len(ty.params):
|
||||||
|
raise UnificationError('Index out of range for tuple')
|
||||||
|
ty.params[index].unify(node.type)
|
||||||
|
else:
|
||||||
|
raise UnificationError('Tuple index must be a constant')
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot use subscript for {ty}')
|
||||||
|
|
||||||
|
def visit_For(self, node):
|
||||||
|
self.visit(node.target)
|
||||||
|
self.visit(node.iter)
|
||||||
|
ty = node.iter.type.find()
|
||||||
|
if isinstance(ty, TVar):
|
||||||
|
# we currently only support iterator over lists
|
||||||
|
ty.type = ty.type.unifier(TVarType.LIST)
|
||||||
|
if 0 in ty.fields:
|
||||||
|
ty.fields[0].unify(node.target.type)
|
||||||
|
else:
|
||||||
|
ty.fields[0] = node.target.type
|
||||||
|
elif isinstance(ty, TList):
|
||||||
|
ty.param.unify(node.target.type)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot iterate over {ty}')
|
||||||
|
|
||||||
|
for stmt in chain(node.body, node.orelse):
|
||||||
|
self.visit(stmt)
|
||||||
|
|
||||||
|
def visit_If(self, node):
|
||||||
|
self.visit(node.test)
|
||||||
|
node.test.type.unify(TBool)
|
||||||
|
for stmt in chain(node.body, node.orelse):
|
||||||
|
self.visit(stmt)
|
||||||
|
|
||||||
|
def visit_While(self, node):
|
||||||
|
self.visit(node.test)
|
||||||
|
node.test.type.unify(TBool)
|
||||||
|
for stmt in chain(node.body, node.orelse):
|
||||||
|
self.visit(stmt)
|
||||||
|
|
||||||
|
def visit_BoolOp(self, node):
|
||||||
|
self.visit(node.values[0])
|
||||||
|
self.visit(node.values[1])
|
||||||
|
node.values[0].type.unify(TBool)
|
||||||
|
node.values[1].type.unify(TBool)
|
||||||
|
node.type = TBool
|
||||||
|
|
||||||
|
def visit_BinOp(self, node):
|
||||||
|
self.visit(node.left)
|
||||||
|
self.visit(node.right)
|
||||||
|
# call method...
|
||||||
|
method = get_magic_method(node.op)
|
||||||
|
ret = TVar()
|
||||||
|
node.type = ret
|
||||||
|
call = TCall([node.right.type], {}, ret)
|
||||||
|
self.calls.append(call)
|
||||||
|
v = TVar()
|
||||||
|
v.type = TVarType.RECORD
|
||||||
|
v.fields[method] = call
|
||||||
|
node.left.type.unify(v)
|
||||||
|
|
||||||
|
def visit_Compare(self, node):
|
||||||
|
self.visit(node.left)
|
||||||
|
for c in node.comparators:
|
||||||
|
self.visit(c)
|
||||||
|
for a, b, c in zip(chain([node.left], node.comparators[:-1]),
|
||||||
|
node.comparators, node.ops):
|
||||||
|
method = get_magic_method(c)
|
||||||
|
call = TCall([b.type], {}, TBool)
|
||||||
|
self.calls.append(call)
|
||||||
|
v = TVar()
|
||||||
|
v.type = TVarType.RECORD
|
||||||
|
v.fields[method] = call
|
||||||
|
a.type.unify(v)
|
||||||
|
node.type = TBool
|
|
@ -0,0 +1,472 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Dict, Mapping, List, Set
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class UnificationError(Exception):
|
||||||
|
def __init__(self, msg):
|
||||||
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
|
class Type:
|
||||||
|
def find(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unify(self, _):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def subst(self, _):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TVarType(Enum):
|
||||||
|
UNDETERMINED = 1
|
||||||
|
SEQUENCE = 2
|
||||||
|
RECORD = 5
|
||||||
|
TUPLE = 6
|
||||||
|
LIST = 8
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
if self.__class__ is other.__class__:
|
||||||
|
return (other.value % self.value) == 0
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def unifier(self, other):
|
||||||
|
if self.__class__ is not other.__class__:
|
||||||
|
raise NotImplementedError()
|
||||||
|
if self <= other:
|
||||||
|
return other
|
||||||
|
elif other <= self:
|
||||||
|
return self
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'cannot unify {self} and {other}')
|
||||||
|
|
||||||
|
|
||||||
|
class TVar(Type):
|
||||||
|
next_id = 0
|
||||||
|
|
||||||
|
def __init__(self, vrange=None):
|
||||||
|
self.type = TVarType.UNDETERMINED
|
||||||
|
self.rank = 0
|
||||||
|
self.parent = self
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
self.fields = {}
|
||||||
|
self.range = vrange
|
||||||
|
self.id = TVar.next_id
|
||||||
|
TVar.next_id += 1
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
if self.range is not None:
|
||||||
|
ty = self.find()
|
||||||
|
# maybe we should replace this with explicit eq
|
||||||
|
if ty is not self and ty not in self.range:
|
||||||
|
raise UnificationError(
|
||||||
|
f'{self.id} cannot be substituted by {ty}')
|
||||||
|
|
||||||
|
def subst(self, mapping: Mapping[int, Type]):
|
||||||
|
# user cannot specify fields...
|
||||||
|
# so this is safe
|
||||||
|
if self.id in mapping:
|
||||||
|
return mapping[self.id]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = self.find()
|
||||||
|
if isinstance(s, TVar):
|
||||||
|
if len(s.fields) > 0:
|
||||||
|
fields = '{' + ', '.join([f'{k}: {v}' for k,
|
||||||
|
v in s.fields.items()]) + '}'
|
||||||
|
else:
|
||||||
|
fields = ''
|
||||||
|
return str(s.id) + fields
|
||||||
|
else:
|
||||||
|
return str(s)
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
root = self
|
||||||
|
parent = self.parent
|
||||||
|
while root is not parent and isinstance(parent, TVar):
|
||||||
|
_, parent = root, root.parent = parent, parent.parent
|
||||||
|
if isinstance(parent, TCall):
|
||||||
|
parent = parent.find()
|
||||||
|
return parent
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
x = other.find()
|
||||||
|
y = self.find()
|
||||||
|
|
||||||
|
if x is y:
|
||||||
|
return
|
||||||
|
if isinstance(y, TVar) and isinstance(x, TVar):
|
||||||
|
# unify field type
|
||||||
|
x.type = x.type.unifier(y.type)
|
||||||
|
# unify fields
|
||||||
|
for k, v in y.fields.items():
|
||||||
|
if k in x.fields:
|
||||||
|
x.fields[k].unify(v)
|
||||||
|
else:
|
||||||
|
x.fields[k] = v
|
||||||
|
# standard union find
|
||||||
|
if x.rank < y.rank:
|
||||||
|
x, y = y, x
|
||||||
|
y.parent = x
|
||||||
|
if x.rank == y.rank:
|
||||||
|
x.rank += 1
|
||||||
|
elif isinstance(y, TVar):
|
||||||
|
# check fields
|
||||||
|
if isinstance(x, TVirtual):
|
||||||
|
obj = x.obj.find()
|
||||||
|
self.unify(obj)
|
||||||
|
elif isinstance(x, TObj):
|
||||||
|
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
||||||
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
|
for k, v in y.fields.items():
|
||||||
|
if k not in x.fields:
|
||||||
|
raise UnificationError(
|
||||||
|
f'Cannot unify {y} with {x}')
|
||||||
|
u = x.fields[k]
|
||||||
|
v.unify(u)
|
||||||
|
elif isinstance(x, TList):
|
||||||
|
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
||||||
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
|
for k, v in y.fields.items():
|
||||||
|
assert isinstance(k, int)
|
||||||
|
v.unify(x.param)
|
||||||
|
elif isinstance(x, TTuple):
|
||||||
|
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
|
||||||
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
|
for k, v in y.fields.items():
|
||||||
|
assert isinstance(k, int)
|
||||||
|
if k >= len(x.params):
|
||||||
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
|
v.unify(x.params[k])
|
||||||
|
y.parent = x
|
||||||
|
else:
|
||||||
|
y.unify(x)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
s = self.find()
|
||||||
|
o = other.find()
|
||||||
|
if not isinstance(s, TVar):
|
||||||
|
return s == o
|
||||||
|
if isinstance(o, TVar):
|
||||||
|
return s.id == o.id
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class FuncArg:
|
||||||
|
def __init__(self, name, typ, is_optional):
|
||||||
|
self.name = name
|
||||||
|
self.typ = typ
|
||||||
|
self.is_optional = is_optional
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'{self.name}: {self.typ}' + ('?' if self.is_optional else '')
|
||||||
|
|
||||||
|
|
||||||
|
class TCall(Type):
|
||||||
|
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
|
||||||
|
self.calls = [[posargs, kwargs, ret, None]]
|
||||||
|
self.parent = self
|
||||||
|
self.rank = 0
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
self.calls[0][3].check()
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
root = self
|
||||||
|
parent = self.parent
|
||||||
|
while root is not parent and isinstance(parent, TCall):
|
||||||
|
_, parent = root, root.parent = parent, parent.parent
|
||||||
|
if parent.calls[0][3] is None:
|
||||||
|
return parent
|
||||||
|
return parent.calls[0][3]
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
y = self.find()
|
||||||
|
if y is not self:
|
||||||
|
y.unify(other)
|
||||||
|
return
|
||||||
|
|
||||||
|
x = other.find()
|
||||||
|
if x is y:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(x, TCall):
|
||||||
|
# standard union find
|
||||||
|
if x.rank < y.rank:
|
||||||
|
x, y = y, x
|
||||||
|
y.parent = x
|
||||||
|
if x.rank == y.rank:
|
||||||
|
x.rank += 1
|
||||||
|
x.calls += y.calls
|
||||||
|
elif isinstance(x, TFunc):
|
||||||
|
fn = x
|
||||||
|
for i in range(len(y.calls)):
|
||||||
|
if not x.instantiated:
|
||||||
|
fn = x.instantiate()
|
||||||
|
posargs, kwargs, ret, _ = y.calls[i]
|
||||||
|
c = y.calls[i]
|
||||||
|
c[3] = fn
|
||||||
|
all_args = set(arg.name for arg in fn.args)
|
||||||
|
required = set(arg.name for arg in fn.args if not
|
||||||
|
arg.is_optional)
|
||||||
|
fn.ret.unify(ret)
|
||||||
|
for i, v in enumerate(posargs):
|
||||||
|
arg = fn.args[i]
|
||||||
|
arg.typ.unify(v)
|
||||||
|
if arg.name in required:
|
||||||
|
required.remove(arg.name)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
arg = next((arg for arg in fn.args if arg.name == k), None)
|
||||||
|
if arg is None:
|
||||||
|
raise UnificationError(f'Unknown kwarg {k}')
|
||||||
|
if k not in all_args:
|
||||||
|
raise UnificationError(f'Duplicated kwarg {k}')
|
||||||
|
arg.typ.unify(v)
|
||||||
|
if k in required:
|
||||||
|
required.remove(k)
|
||||||
|
all_args.remove(k)
|
||||||
|
if len(required) > 0:
|
||||||
|
raise UnificationError(f'Missing arguments')
|
||||||
|
elif isinstance(other, TVar):
|
||||||
|
other.unify(self)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify a call with {other}')
|
||||||
|
|
||||||
|
|
||||||
|
class TFunc(Type):
|
||||||
|
def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
|
||||||
|
self.args = args
|
||||||
|
self.ret = ret
|
||||||
|
self.vars = vars
|
||||||
|
self.instantiated = False
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
for arg in self.args:
|
||||||
|
arg.typ.check()
|
||||||
|
self.ret.check()
|
||||||
|
|
||||||
|
def subst(self, mapping: Mapping[int, Type]):
|
||||||
|
if len(mapping) == 0:
|
||||||
|
return self
|
||||||
|
return TFunc(
|
||||||
|
[FuncArg(arg.name, arg.typ.subst(mapping), arg.is_optional)
|
||||||
|
for arg in self.args],
|
||||||
|
self.ret.subst(mapping),
|
||||||
|
self.vars)
|
||||||
|
|
||||||
|
def instantiate(self):
|
||||||
|
mapping = {v.id: TVar(v.range) if isinstance(v, TVar)
|
||||||
|
else TVar() for v in self.vars}
|
||||||
|
result = self.subst(mapping)
|
||||||
|
result.instantiated = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'({", ".join(str(arg) for arg in self.args)}) -> {self.ret}'
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
other = other.find()
|
||||||
|
if other is self:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(other, (TVar, TCall)):
|
||||||
|
other.unify(self)
|
||||||
|
elif isinstance(other, TFunc):
|
||||||
|
if len(self.args) != len(other.args):
|
||||||
|
raise UnificationError(
|
||||||
|
f'cannot unify functions with different parameters')
|
||||||
|
self.ret.unify(other.ret)
|
||||||
|
for a, b in zip(self.args, other.args):
|
||||||
|
if a.name != b.name or a.is_optional != b.is_optional:
|
||||||
|
raise UnificationError(
|
||||||
|
f'cannot unify functions with different parameters')
|
||||||
|
a.typ.unify(b.typ)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify a function with {other}')
|
||||||
|
|
||||||
|
|
||||||
|
class TObj(Type):
|
||||||
|
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type],
|
||||||
|
parents=None):
|
||||||
|
self.name = name
|
||||||
|
self.fields = fields
|
||||||
|
self.params = params
|
||||||
|
self.checked = False
|
||||||
|
if parents is None:
|
||||||
|
self.parents = []
|
||||||
|
else:
|
||||||
|
self.parents = parents
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
for arg in self.fields.values():
|
||||||
|
arg.check()
|
||||||
|
|
||||||
|
def subst(self, mapping: Mapping[int, Type]):
|
||||||
|
if len(mapping) == 0:
|
||||||
|
return self
|
||||||
|
new_params = []
|
||||||
|
changed = False
|
||||||
|
for v in self.params:
|
||||||
|
if isinstance(v, TVar) and v.id in mapping:
|
||||||
|
new_params.append(mapping[v.id])
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
new_params.append(v)
|
||||||
|
if not changed:
|
||||||
|
return self
|
||||||
|
return TObj(self.name, {k: v.subst(mapping) for k, v in
|
||||||
|
self.fields.items()}, new_params)
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
other = other.find()
|
||||||
|
if other is self:
|
||||||
|
return
|
||||||
|
if isinstance(other, TObj):
|
||||||
|
if self.name != other.name:
|
||||||
|
raise UnificationError(f'Cannot unify {self} with {other}')
|
||||||
|
for k in self.fields:
|
||||||
|
self.fields[k].unify(other.fields[k])
|
||||||
|
elif isinstance(other, TVar):
|
||||||
|
other.unify(self)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify an object with {other}')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if len(self.params) > 0:
|
||||||
|
p = '[' + ', '.join(str(p) for p in self.params) + ']'
|
||||||
|
else:
|
||||||
|
p = ''
|
||||||
|
return self.name + p
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TObj):
|
||||||
|
if self.name != o.name:
|
||||||
|
return False
|
||||||
|
for a, b in zip(self.params, o.params):
|
||||||
|
if a != b:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TVirtual(Type):
|
||||||
|
def __init__(self, obj: Type):
|
||||||
|
self.obj = obj
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
self.obj.check()
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TVirtual):
|
||||||
|
return self == o
|
||||||
|
return False
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TVirtual):
|
||||||
|
self.obj.unify(o.obj)
|
||||||
|
elif isinstance(o, TVar):
|
||||||
|
o.unify(self)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify {self} with {o}')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'virtual[{self.obj}]'
|
||||||
|
|
||||||
|
|
||||||
|
class TList(Type):
|
||||||
|
def __init__(self, param: Type):
|
||||||
|
self.param = param
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
self.param.check()
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
other = other.find()
|
||||||
|
if isinstance(other, TVar):
|
||||||
|
other.unify(self)
|
||||||
|
elif isinstance(other, TList):
|
||||||
|
self.param.unify(other.param)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify list with {other}')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'List[{self.param}]'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TList):
|
||||||
|
return self.param == o.param
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TTuple(Type):
|
||||||
|
def __init__(self, params: List[Type]):
|
||||||
|
self.params = params
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
for p in self.params:
|
||||||
|
p.check()
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
other = other.find()
|
||||||
|
if isinstance(other, TVar):
|
||||||
|
other.unify(self)
|
||||||
|
elif isinstance(other, TTuple):
|
||||||
|
if len(self.params) != len(other.params):
|
||||||
|
raise UnificationError(
|
||||||
|
'cannot unify tuples with different length')
|
||||||
|
for a, b in zip(self.params, other.params):
|
||||||
|
a.unify(b)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify {self} with {other}')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'Tuple[{", ".join(str(p) for p in self.params)}]'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TTuple):
|
||||||
|
if len(self.params) != len(o.params):
|
||||||
|
return False
|
||||||
|
for a, b in zip(self.params, o.params):
|
||||||
|
if a != b:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
from nac3_types import *
|
||||||
|
|
||||||
|
|
||||||
|
TBool = TObj('bool', {}, [])
|
||||||
|
TInt = TObj('int', {}, [])
|
||||||
|
TFloat = TObj('float', {}, [])
|
||||||
|
|
||||||
|
TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||||
|
TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||||
|
|
||||||
|
|
||||||
|
def impl_cmp(ty):
|
||||||
|
ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
|
||||||
|
|
||||||
|
def impl_arithmetic(ty):
|
||||||
|
ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
|
||||||
|
|
||||||
|
impl_cmp(TInt)
|
||||||
|
impl_cmp(TFloat)
|
||||||
|
impl_arithmetic(TInt)
|
||||||
|
impl_arithmetic(TFloat)
|
||||||
|
|
||||||
|
TNum = TVar([TInt, TFloat])
|
||||||
|
|
||||||
|
TInt.fields['__truediv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
TInt.fields['__floordiv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TInt, [TNum])
|
||||||
|
TFloat.fields['__truediv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
TFloat.fields['__floordiv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import ast
|
||||||
|
from ast_visitor import Visitor
|
||||||
|
from nac3_types import *
|
||||||
|
from primitives import *
|
||||||
|
|
||||||
|
src = """
|
||||||
|
|
||||||
|
a = [
|
||||||
|
virtual(bar),
|
||||||
|
virtual(foo),
|
||||||
|
]
|
||||||
|
|
||||||
|
for x in a:
|
||||||
|
test_virtual(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
foo = TObj('Foo', {
|
||||||
|
'a': TInt,
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
foo2 = TObj('Foo2', {
|
||||||
|
'a': TInt,
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
bar = TObj('Bar', {
|
||||||
|
'a': TInt,
|
||||||
|
'b': TInt
|
||||||
|
}, [], [foo])
|
||||||
|
|
||||||
|
type_mapping = {
|
||||||
|
'Foo': foo,
|
||||||
|
'Foo2': foo2,
|
||||||
|
'Bar': bar,
|
||||||
|
}
|
||||||
|
|
||||||
|
prelude = {
|
||||||
|
'foo': foo,
|
||||||
|
'foo2': foo2,
|
||||||
|
'bar': bar,
|
||||||
|
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
|
||||||
|
}
|
||||||
|
|
||||||
|
print('-----------')
|
||||||
|
print('prelude')
|
||||||
|
for key, value in prelude.items():
|
||||||
|
print(f'{key}: {value}')
|
||||||
|
print('-----------')
|
||||||
|
|
||||||
|
v = Visitor(src, prelude.copy(), lambda x: type_mapping[x])
|
||||||
|
|
||||||
|
print(src)
|
||||||
|
v.visit(ast.parse(src))
|
||||||
|
|
||||||
|
for a, b in v.virtuals:
|
||||||
|
assert isinstance(a, TObj)
|
||||||
|
assert b.find() is a or b.find() in a.parents
|
||||||
|
|
||||||
|
print('-----------')
|
||||||
|
print('calls')
|
||||||
|
for x in v.calls:
|
||||||
|
x.check()
|
||||||
|
print(f'{x.find()}')
|
||||||
|
|
||||||
|
print('-----------')
|
||||||
|
print('assignments')
|
||||||
|
for key, value in v.assignments.items():
|
||||||
|
if key not in prelude:
|
||||||
|
value.check()
|
||||||
|
print(f'{key}: {value.find()}')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue