From 39b3faba6ee7e94e127d3825a415dde00c1f553c Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 12 Jul 2021 10:36:52 +0800 Subject: [PATCH] fixed for loop unification --- hm-inference/ast_visitor.py | 2 +- hm-inference/nac3_types.py | 12 ++++-------- hm-inference/test.py | 11 ++++++----- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index d1f81ff..0f8c2b0 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -172,7 +172,7 @@ class Visitor(ast.NodeVisitor): def visit_For(self, node): self.visit(node.target) self.visit(node.iter) - ty = node.iter.type + ty = node.iter.type.find() if isinstance(ty, TVar): # we currently only support iterator over lists ty.type = ty.type.unifier(TVarType.LIST) diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index 2401c2b..f5d9caa 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -122,14 +122,8 @@ class TVar(Type): elif isinstance(y, TVar): # check fields if isinstance(x, TVirtual): - if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: - raise UnificationError(f'Cannot unify {y} with {x}') - for k, v in y.fields.items(): - if k not in x.obj.fields: - raise UnificationError( - f'Cannot unify {y} with {x}') - u = x.obj.fields[k] - v.unify(u) + obj = x.obj.find() + self.unify(obj) elif isinstance(x, TObj): if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: raise UnificationError(f'Cannot unify {y} with {x}') @@ -397,6 +391,8 @@ class TVirtual(Type): o = other.find() if isinstance(o, TVirtual): self.obj.unify(o.obj) + elif isinstance(o, TVar): + o.unify(self) else: raise UnificationError(f'Cannot unify {self} with {o}') diff --git a/hm-inference/test.py b/hm-inference/test.py index 8595807..18a89ec 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -5,13 +5,14 @@ from nac3_types import * from primitives import * src = """ -a = virtual(bar) -b = test_virtual(a) -a = virtual(foo) -b = test_virtual(a) +a = [ + virtual(bar), + virtual(foo), +] -c = virtual(foo, Foo) +for x in a: + test_virtual(x) """ foo = TObj('Foo', {