type guard
This commit is contained in:
parent
b8719e5214
commit
0453273d8b
|
@ -19,9 +19,7 @@ for simplicity reasons:
|
|||
* `with`, `try except`, etc.
|
||||
* const indexing with tuple.
|
||||
* method override check modulo variable renaming.
|
||||
|
||||
These features are currently not implemented, and would be added in due course:
|
||||
* Type guards.
|
||||
* more complicated type guard
|
||||
|
||||
|
||||
## Type Check Implementation
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
I = TypeVar('I', int32, int64)
|
||||
I = TypeVar('I', int32, Vec)
|
||||
|
||||
class Vec:
|
||||
v: list[int32]
|
||||
def __init__(self, v: list[int32]):
|
||||
self.v = v
|
||||
|
||||
def __add__(self, other: int32) -> Vec:
|
||||
def __add__(self, other: I) -> Vec:
|
||||
if other is int32:
|
||||
return Vec([v + other for v in self.v])
|
||||
else:
|
||||
return Vec([self.v[i] + other.v[i] for i in range(len(self.v))])
|
||||
|
||||
|
||||
def addI(a: I, b: I) -> I:
|
||||
return a + b
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ def find_subst(ctx: dict[str, Type],
|
|||
return sub
|
||||
|
||||
if isinstance(a, TypeVariable):
|
||||
if a == b:
|
||||
if len(a.constraints) == 1 and a.constraints[0] == b:
|
||||
return sub
|
||||
else:
|
||||
return f"{a} can take values other than {b}"
|
||||
|
|
|
@ -107,6 +107,19 @@ def parse_attribute(ctx: Context,
|
|||
sym_table: dict[str, Type],
|
||||
node):
|
||||
obj = parse_expr(ctx, sym_table, node.value)
|
||||
|
||||
if isinstance(obj, TypeVariable) and len(obj.constraints) > 0:
|
||||
if node.attr not in obj.constraints[0].fields:
|
||||
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||
ty = obj.constraints[0].fields[node.attr]
|
||||
for v in obj.constraints[1:]:
|
||||
if node.attr not in v.fields:
|
||||
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||
if v.fields[node.attr] != ty:
|
||||
raise CustomError(
|
||||
f'unknown field {node.attr} in {obj} (type mismatch)', node)
|
||||
return ty
|
||||
|
||||
if node.attr in obj.fields:
|
||||
return obj.fields[node.attr]
|
||||
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||
|
|
|
@ -3,6 +3,7 @@ import copy
|
|||
from helper import *
|
||||
from type_def import *
|
||||
from inference import *
|
||||
from top_level import parse_type
|
||||
from parse_expr import parse_expr, parse_simple_binding
|
||||
|
||||
def parse_stmts(ctx: Context,
|
||||
|
@ -113,6 +114,36 @@ def parse_if_stmt(ctx: Context,
|
|||
used_sym_table: dict[str, Type],
|
||||
return_ty: Type,
|
||||
node):
|
||||
if isinstance(node.test, ast.Compare) and \
|
||||
len(node.test.ops) == 1 and \
|
||||
(isinstance(node.test.ops[0], ast.Is) or\
|
||||
isinstance(node.test.ops[0], ast.IsNot)):
|
||||
if not isinstance(node.test.left, ast.Name):
|
||||
raise CustomError(
|
||||
'type guard only support testing variables',
|
||||
node.test)
|
||||
t = parse_expr(ctx, sym_table, node.test.left)
|
||||
if not isinstance(t, TypeVariable) or len(t.constraints) < 2:
|
||||
raise CustomError(
|
||||
'type guard only support basic type variables with constraints',
|
||||
node.test)
|
||||
t1, _ = parse_type(ctx, node.test.comparators[0])
|
||||
if t1 not in t.constraints:
|
||||
raise CustomError(
|
||||
f'{t1} is not in constraints of {t}',
|
||||
node.test)
|
||||
t2 = [v for v in t.constraints if v != t1]
|
||||
try:
|
||||
t.constraints = [t1]
|
||||
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
||||
used_sym_table |= b
|
||||
t.constraints = t2
|
||||
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
|
||||
defined = {k: a[k] for k in a if k in a1}
|
||||
return defined, b | b1, r and r1
|
||||
finally:
|
||||
t.constraints = t
|
||||
else:
|
||||
test = parse_expr(ctx, sym_table, node.test)
|
||||
if test != ctx.types['bool']:
|
||||
raise CustomError(f'condition must be bool instead of {test}', node.test)
|
||||
|
@ -169,6 +200,10 @@ def parse_return_stmt(ctx: Context,
|
|||
return {}, {}, True
|
||||
ty = parse_expr(ctx, sym_table, node.value)
|
||||
if ty != return_ty:
|
||||
if isinstance(return_ty, TypeVariable):
|
||||
if len(return_ty.constraints) == 1 and \
|
||||
return_ty.constraints[0] == ty:
|
||||
return {}, {}, True
|
||||
raise CustomError(f'expected returning {return_ty} but got {ty}', node)
|
||||
return {}, {}, True
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ def parse_type_var(ctx: Context, node):
|
|||
not isinstance(node.value.args[0].value, str):
|
||||
raise CustomError('call to TypeVar must at least have a name',
|
||||
node.value)
|
||||
name = node.value.args[0]
|
||||
name = node.targets[0].id
|
||||
if name in ctx.variables:
|
||||
raise CustomError('redefining type variable is not allowed', node)
|
||||
constraints = []
|
||||
|
@ -137,7 +137,7 @@ def parse_type_var(ctx: Context, node):
|
|||
if value not in ctx.types:
|
||||
raise CustomError(f'unbounded type {value}', node)
|
||||
constraints.append(ctx.types[value])
|
||||
ctx.variables[node.targets[0].id] = TypeVariable(name, constraints)
|
||||
ctx.variables[name] = TypeVariable(node.value.args[0].value, constraints)
|
||||
|
||||
|
||||
def parse_top_level(ctx: Context, module: ast.Module):
|
||||
|
|
Loading…
Reference in New Issue