From 0453273d8b4e0c8df3ad0b40b51dfab19a317c44 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 23 Dec 2020 11:22:17 +0800 Subject: [PATCH] type guard --- toy-impl/README.md | 4 +--- toy-impl/examples/a.py | 12 +++++----- toy-impl/inference.py | 2 +- toy-impl/parse_expr.py | 13 +++++++++++ toy-impl/parse_stmt.py | 51 +++++++++++++++++++++++++++++++++++------- toy-impl/top_level.py | 4 ++-- 6 files changed, 66 insertions(+), 20 deletions(-) diff --git a/toy-impl/README.md b/toy-impl/README.md index f9fea16..3a8ecc0 100644 --- a/toy-impl/README.md +++ b/toy-impl/README.md @@ -19,9 +19,7 @@ for simplicity reasons: * `with`, `try except`, etc. * const indexing with tuple. * method override check modulo variable renaming. - -These features are currently not implemented, and would be added in due course: -* Type guards. +* more complicated type guard ## Type Check Implementation diff --git a/toy-impl/examples/a.py b/toy-impl/examples/a.py index 3ee1510..98c9a21 100644 --- a/toy-impl/examples/a.py +++ b/toy-impl/examples/a.py @@ -1,14 +1,14 @@ -I = TypeVar('I', int32, int64) +I = TypeVar('I', int32, Vec) class Vec: v: list[int32] def __init__(self, v: list[int32]): self.v = v - def __add__(self, other: int32) -> Vec: - return Vec([v + other for v in self.v]) + def __add__(self, other: I) -> Vec: + if other is int32: + return Vec([v + other for v in self.v]) + else: + return Vec([self.v[i] + other.v[i] for i in range(len(self.v))]) -def addI(a: I, b: I) -> I: - return a + b - diff --git a/toy-impl/inference.py b/toy-impl/inference.py index 1d4441d..0195f80 100644 --- a/toy-impl/inference.py +++ b/toy-impl/inference.py @@ -43,7 +43,7 @@ def find_subst(ctx: dict[str, Type], return sub if isinstance(a, TypeVariable): - if a == b: + if len(a.constraints) == 1 and a.constraints[0] == b: return sub else: return f"{a} can take values other than {b}" diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py index a4979b0..aa0d945 100644 --- a/toy-impl/parse_expr.py +++ b/toy-impl/parse_expr.py @@ -107,6 +107,19 @@ def parse_attribute(ctx: Context, sym_table: dict[str, Type], node): obj = parse_expr(ctx, sym_table, node.value) + + if isinstance(obj, TypeVariable) and len(obj.constraints) > 0: + if node.attr not in obj.constraints[0].fields: + raise CustomError(f'unknown field {node.attr} in {obj}', node) + ty = obj.constraints[0].fields[node.attr] + for v in obj.constraints[1:]: + if node.attr not in v.fields: + raise CustomError(f'unknown field {node.attr} in {obj}', node) + if v.fields[node.attr] != ty: + raise CustomError( + f'unknown field {node.attr} in {obj} (type mismatch)', node) + return ty + if node.attr in obj.fields: return obj.fields[node.attr] raise CustomError(f'unknown field {node.attr} in {obj}', node) diff --git a/toy-impl/parse_stmt.py b/toy-impl/parse_stmt.py index f8d1cab..4e17212 100644 --- a/toy-impl/parse_stmt.py +++ b/toy-impl/parse_stmt.py @@ -3,6 +3,7 @@ import copy from helper import * from type_def import * from inference import * +from top_level import parse_type from parse_expr import parse_expr, parse_simple_binding def parse_stmts(ctx: Context, @@ -113,14 +114,44 @@ def parse_if_stmt(ctx: Context, used_sym_table: dict[str, Type], return_ty: Type, node): - test = parse_expr(ctx, sym_table, node.test) - if test != ctx.types['bool']: - raise CustomError(f'condition must be bool instead of {test}', node.test) - a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body) - used_sym_table |= b - a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse) - defined = {k: a[k] for k in a if k in a1} - return defined, b | b1, r and r1 + if isinstance(node.test, ast.Compare) and \ + len(node.test.ops) == 1 and \ + (isinstance(node.test.ops[0], ast.Is) or\ + isinstance(node.test.ops[0], ast.IsNot)): + if not isinstance(node.test.left, ast.Name): + raise CustomError( + 'type guard only support testing variables', + node.test) + t = parse_expr(ctx, sym_table, node.test.left) + if not isinstance(t, TypeVariable) or len(t.constraints) < 2: + raise CustomError( + 'type guard only support basic type variables with constraints', + node.test) + t1, _ = parse_type(ctx, node.test.comparators[0]) + if t1 not in t.constraints: + raise CustomError( + f'{t1} is not in constraints of {t}', + node.test) + t2 = [v for v in t.constraints if v != t1] + try: + t.constraints = [t1] + a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body) + used_sym_table |= b + t.constraints = t2 + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse) + defined = {k: a[k] for k in a if k in a1} + return defined, b | b1, r and r1 + finally: + t.constraints = t + else: + test = parse_expr(ctx, sym_table, node.test) + if test != ctx.types['bool']: + raise CustomError(f'condition must be bool instead of {test}', node.test) + a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body) + used_sym_table |= b + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse) + defined = {k: a[k] for k in a if k in a1} + return defined, b | b1, r and r1 def parse_for_stmt(ctx: Context, sym_table: dict[str, Type], @@ -169,6 +200,10 @@ def parse_return_stmt(ctx: Context, return {}, {}, True ty = parse_expr(ctx, sym_table, node.value) if ty != return_ty: + if isinstance(return_ty, TypeVariable): + if len(return_ty.constraints) == 1 and \ + return_ty.constraints[0] == ty: + return {}, {}, True raise CustomError(f'expected returning {return_ty} but got {ty}', node) return {}, {}, True diff --git a/toy-impl/top_level.py b/toy-impl/top_level.py index 33ffa2c..86594c7 100644 --- a/toy-impl/top_level.py +++ b/toy-impl/top_level.py @@ -121,7 +121,7 @@ def parse_type_var(ctx: Context, node): not isinstance(node.value.args[0].value, str): raise CustomError('call to TypeVar must at least have a name', node.value) - name = node.value.args[0] + name = node.targets[0].id if name in ctx.variables: raise CustomError('redefining type variable is not allowed', node) constraints = [] @@ -137,7 +137,7 @@ def parse_type_var(ctx: Context, node): if value not in ctx.types: raise CustomError(f'unbounded type {value}', node) constraints.append(ctx.types[value]) - ctx.variables[node.targets[0].id] = TypeVariable(name, constraints) + ctx.variables[name] = TypeVariable(node.value.args[0].value, constraints) def parse_top_level(ctx: Context, module: ast.Module):