type guard

pull/14/head
pca006132 2020-12-23 11:22:17 +08:00 committed by pca006132
parent b8719e5214
commit 0453273d8b
6 changed files with 66 additions and 20 deletions

View File

@ -19,9 +19,7 @@ for simplicity reasons:
* `with`, `try except`, etc.
* const indexing with tuple.
* method override check modulo variable renaming.
These features are currently not implemented, and would be added in due course:
* Type guards.
* more complicated type guard
## Type Check Implementation

View File

@ -1,14 +1,14 @@
I = TypeVar('I', int32, int64)
I = TypeVar('I', int32, Vec)
class Vec:
v: list[int32]
def __init__(self, v: list[int32]):
self.v = v
def __add__(self, other: int32) -> Vec:
return Vec([v + other for v in self.v])
def __add__(self, other: I) -> Vec:
if other is int32:
return Vec([v + other for v in self.v])
else:
return Vec([self.v[i] + other.v[i] for i in range(len(self.v))])
def addI(a: I, b: I) -> I:
return a + b

View File

@ -43,7 +43,7 @@ def find_subst(ctx: dict[str, Type],
return sub
if isinstance(a, TypeVariable):
if a == b:
if len(a.constraints) == 1 and a.constraints[0] == b:
return sub
else:
return f"{a} can take values other than {b}"

View File

@ -107,6 +107,19 @@ def parse_attribute(ctx: Context,
sym_table: dict[str, Type],
node):
obj = parse_expr(ctx, sym_table, node.value)
if isinstance(obj, TypeVariable) and len(obj.constraints) > 0:
if node.attr not in obj.constraints[0].fields:
raise CustomError(f'unknown field {node.attr} in {obj}', node)
ty = obj.constraints[0].fields[node.attr]
for v in obj.constraints[1:]:
if node.attr not in v.fields:
raise CustomError(f'unknown field {node.attr} in {obj}', node)
if v.fields[node.attr] != ty:
raise CustomError(
f'unknown field {node.attr} in {obj} (type mismatch)', node)
return ty
if node.attr in obj.fields:
return obj.fields[node.attr]
raise CustomError(f'unknown field {node.attr} in {obj}', node)

View File

@ -3,6 +3,7 @@ import copy
from helper import *
from type_def import *
from inference import *
from top_level import parse_type
from parse_expr import parse_expr, parse_simple_binding
def parse_stmts(ctx: Context,
@ -113,14 +114,44 @@ def parse_if_stmt(ctx: Context,
used_sym_table: dict[str, Type],
return_ty: Type,
node):
test = parse_expr(ctx, sym_table, node.test)
if test != ctx.types['bool']:
raise CustomError(f'condition must be bool instead of {test}', node.test)
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
used_sym_table |= b
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
defined = {k: a[k] for k in a if k in a1}
return defined, b | b1, r and r1
if isinstance(node.test, ast.Compare) and \
len(node.test.ops) == 1 and \
(isinstance(node.test.ops[0], ast.Is) or\
isinstance(node.test.ops[0], ast.IsNot)):
if not isinstance(node.test.left, ast.Name):
raise CustomError(
'type guard only support testing variables',
node.test)
t = parse_expr(ctx, sym_table, node.test.left)
if not isinstance(t, TypeVariable) or len(t.constraints) < 2:
raise CustomError(
'type guard only support basic type variables with constraints',
node.test)
t1, _ = parse_type(ctx, node.test.comparators[0])
if t1 not in t.constraints:
raise CustomError(
f'{t1} is not in constraints of {t}',
node.test)
t2 = [v for v in t.constraints if v != t1]
try:
t.constraints = [t1]
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
used_sym_table |= b
t.constraints = t2
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
defined = {k: a[k] for k in a if k in a1}
return defined, b | b1, r and r1
finally:
t.constraints = t
else:
test = parse_expr(ctx, sym_table, node.test)
if test != ctx.types['bool']:
raise CustomError(f'condition must be bool instead of {test}', node.test)
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
used_sym_table |= b
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
defined = {k: a[k] for k in a if k in a1}
return defined, b | b1, r and r1
def parse_for_stmt(ctx: Context,
sym_table: dict[str, Type],
@ -169,6 +200,10 @@ def parse_return_stmt(ctx: Context,
return {}, {}, True
ty = parse_expr(ctx, sym_table, node.value)
if ty != return_ty:
if isinstance(return_ty, TypeVariable):
if len(return_ty.constraints) == 1 and \
return_ty.constraints[0] == ty:
return {}, {}, True
raise CustomError(f'expected returning {return_ty} but got {ty}', node)
return {}, {}, True

View File

@ -121,7 +121,7 @@ def parse_type_var(ctx: Context, node):
not isinstance(node.value.args[0].value, str):
raise CustomError('call to TypeVar must at least have a name',
node.value)
name = node.value.args[0]
name = node.targets[0].id
if name in ctx.variables:
raise CustomError('redefining type variable is not allowed', node)
constraints = []
@ -137,7 +137,7 @@ def parse_type_var(ctx: Context, node):
if value not in ctx.types:
raise CustomError(f'unbounded type {value}', node)
constraints.append(ctx.types[value])
ctx.variables[node.targets[0].id] = TypeVariable(name, constraints)
ctx.variables[name] = TypeVariable(node.value.args[0].value, constraints)
def parse_top_level(ctx: Context, module: ast.Module):