From 4577d0cc12a9cda8125aea3163e1501f9d4f7c5b Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 23 Dec 2020 16:05:45 +0800 Subject: [PATCH] fixed variable freshness bug --- toy-impl/README.md | 8 ++++---- toy-impl/examples/a.py | 11 +++++++++++ toy-impl/main.py | 9 ++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/toy-impl/README.md b/toy-impl/README.md index dabd3d3..8a61b41 100644 --- a/toy-impl/README.md +++ b/toy-impl/README.md @@ -52,7 +52,7 @@ doing unification. Consider the following example: -```py +```python X = TypeVar('X') def head(a: list[X]) -> X: @@ -68,7 +68,7 @@ algorithm tries to fit `(list[int32])` into `(list[X])`, giving a substitution Substitution can also substitute variables into another variable. Consider the following example: -``` +```python X = TypeVar('X') Y = TypeVar('Y', int32, int64) @@ -89,7 +89,7 @@ So the function is well typed. Note that variables are fresh in every invocation. Consider the following example: -``` +```python I = TypeVar('I', int32, list[int32]) def add(a: int32, b: I) -> int32: @@ -102,7 +102,7 @@ def add(a: int32, b: I) -> int32: add(1, [1, 2, 3]) ``` -This one should type check (bug now). `I -> list[int32]` only affects 1 call, +This one should type check. `I -> list[int32]` only affects 1 call, and the recursion inside could substitute `I -> int32`. ## Variable Scoping diff --git a/toy-impl/examples/a.py b/toy-impl/examples/a.py index 15165db..11616ff 100644 --- a/toy-impl/examples/a.py +++ b/toy-impl/examples/a.py @@ -14,3 +14,14 @@ class Vec: return Vec([self.v[i] + other.v[i] for i in range(len(self.v))]) +T = TypeVar('T', int32, list[int32]) + +def add(a: int32, b: T) -> int32: + if type(b) == int32: + return a + b + else: + for x in b: + a = add(a, x) + return a + + diff --git a/toy-impl/main.py b/toy-impl/main.py index dee3b1a..904a34c 100644 --- a/toy-impl/main.py +++ b/toy-impl/main.py @@ -1,5 +1,6 @@ import ast import sys +import copy from helper import CustomError from type_def import SelfType, ClassType from parse_stmt import parse_stmts @@ -24,9 +25,11 @@ try: for c, name, fn in fns: if c is None: - params, result, _ = ctx.functions[name] + params, result, var = ctx.functions[name] else: - params, result, _ = ctx.types[c].methods[name] + params, result, var = ctx.types[c].methods[name] + # create substitution for type variables + subst = {k: copy.deepcopy(ctx.variables[k]) for k in var} # check if fully annotated all params sym_table = {} for n, ty in zip(fn.args.args, params): @@ -36,7 +39,7 @@ try: fn) if isinstance(ty, SelfType): ty = ctx.types[c] - sym_table[n.arg] = ty + sym_table[n.arg] = ty.subst(subst) _, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body) if result is not None and not returned: raise CustomError('Function may have no return value', fn)