fixed for loop unification

This commit is contained in:
pca006132 2021-07-12 10:36:52 +08:00
parent 1ad21f0d67
commit 39b3faba6e
3 changed files with 11 additions and 14 deletions

View File

@ -172,7 +172,7 @@ class Visitor(ast.NodeVisitor):
def visit_For(self, node):
self.visit(node.target)
self.visit(node.iter)
ty = node.iter.type
ty = node.iter.type.find()
if isinstance(ty, TVar):
# we currently only support iterator over lists
ty.type = ty.type.unifier(TVarType.LIST)

View File

@ -122,14 +122,8 @@ class TVar(Type):
elif isinstance(y, TVar):
# check fields
if isinstance(x, TVirtual):
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
raise UnificationError(f'Cannot unify {y} with {x}')
for k, v in y.fields.items():
if k not in x.obj.fields:
raise UnificationError(
f'Cannot unify {y} with {x}')
u = x.obj.fields[k]
v.unify(u)
obj = x.obj.find()
self.unify(obj)
elif isinstance(x, TObj):
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
raise UnificationError(f'Cannot unify {y} with {x}')
@ -397,6 +391,8 @@ class TVirtual(Type):
o = other.find()
if isinstance(o, TVirtual):
self.obj.unify(o.obj)
elif isinstance(o, TVar):
o.unify(self)
else:
raise UnificationError(f'Cannot unify {self} with {o}')

View File

@ -5,13 +5,14 @@ from nac3_types import *
from primitives import *
src = """
a = virtual(bar)
b = test_virtual(a)
a = virtual(foo)
b = test_virtual(a)
a = [
virtual(bar),
virtual(foo),
]
c = virtual(foo, Foo)
for x in a:
test_virtual(x)
"""
foo = TObj('Foo', {