type guard
This commit is contained in:
parent
b8719e5214
commit
0453273d8b
|
@ -19,9 +19,7 @@ for simplicity reasons:
|
||||||
* `with`, `try except`, etc.
|
* `with`, `try except`, etc.
|
||||||
* const indexing with tuple.
|
* const indexing with tuple.
|
||||||
* method override check modulo variable renaming.
|
* method override check modulo variable renaming.
|
||||||
|
* more complicated type guard
|
||||||
These features are currently not implemented, and would be added in due course:
|
|
||||||
* Type guards.
|
|
||||||
|
|
||||||
|
|
||||||
## Type Check Implementation
|
## Type Check Implementation
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
I = TypeVar('I', int32, int64)
|
I = TypeVar('I', int32, Vec)
|
||||||
|
|
||||||
class Vec:
|
class Vec:
|
||||||
v: list[int32]
|
v: list[int32]
|
||||||
def __init__(self, v: list[int32]):
|
def __init__(self, v: list[int32]):
|
||||||
self.v = v
|
self.v = v
|
||||||
|
|
||||||
def __add__(self, other: int32) -> Vec:
|
def __add__(self, other: I) -> Vec:
|
||||||
return Vec([v + other for v in self.v])
|
if other is int32:
|
||||||
|
return Vec([v + other for v in self.v])
|
||||||
|
else:
|
||||||
|
return Vec([self.v[i] + other.v[i] for i in range(len(self.v))])
|
||||||
|
|
||||||
|
|
||||||
def addI(a: I, b: I) -> I:
|
|
||||||
return a + b
|
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ def find_subst(ctx: dict[str, Type],
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
if isinstance(a, TypeVariable):
|
if isinstance(a, TypeVariable):
|
||||||
if a == b:
|
if len(a.constraints) == 1 and a.constraints[0] == b:
|
||||||
return sub
|
return sub
|
||||||
else:
|
else:
|
||||||
return f"{a} can take values other than {b}"
|
return f"{a} can take values other than {b}"
|
||||||
|
|
|
@ -107,6 +107,19 @@ def parse_attribute(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
node):
|
node):
|
||||||
obj = parse_expr(ctx, sym_table, node.value)
|
obj = parse_expr(ctx, sym_table, node.value)
|
||||||
|
|
||||||
|
if isinstance(obj, TypeVariable) and len(obj.constraints) > 0:
|
||||||
|
if node.attr not in obj.constraints[0].fields:
|
||||||
|
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||||
|
ty = obj.constraints[0].fields[node.attr]
|
||||||
|
for v in obj.constraints[1:]:
|
||||||
|
if node.attr not in v.fields:
|
||||||
|
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||||
|
if v.fields[node.attr] != ty:
|
||||||
|
raise CustomError(
|
||||||
|
f'unknown field {node.attr} in {obj} (type mismatch)', node)
|
||||||
|
return ty
|
||||||
|
|
||||||
if node.attr in obj.fields:
|
if node.attr in obj.fields:
|
||||||
return obj.fields[node.attr]
|
return obj.fields[node.attr]
|
||||||
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
raise CustomError(f'unknown field {node.attr} in {obj}', node)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import copy
|
||||||
from helper import *
|
from helper import *
|
||||||
from type_def import *
|
from type_def import *
|
||||||
from inference import *
|
from inference import *
|
||||||
|
from top_level import parse_type
|
||||||
from parse_expr import parse_expr, parse_simple_binding
|
from parse_expr import parse_expr, parse_simple_binding
|
||||||
|
|
||||||
def parse_stmts(ctx: Context,
|
def parse_stmts(ctx: Context,
|
||||||
|
@ -113,14 +114,44 @@ def parse_if_stmt(ctx: Context,
|
||||||
used_sym_table: dict[str, Type],
|
used_sym_table: dict[str, Type],
|
||||||
return_ty: Type,
|
return_ty: Type,
|
||||||
node):
|
node):
|
||||||
test = parse_expr(ctx, sym_table, node.test)
|
if isinstance(node.test, ast.Compare) and \
|
||||||
if test != ctx.types['bool']:
|
len(node.test.ops) == 1 and \
|
||||||
raise CustomError(f'condition must be bool instead of {test}', node.test)
|
(isinstance(node.test.ops[0], ast.Is) or\
|
||||||
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
isinstance(node.test.ops[0], ast.IsNot)):
|
||||||
used_sym_table |= b
|
if not isinstance(node.test.left, ast.Name):
|
||||||
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
|
raise CustomError(
|
||||||
defined = {k: a[k] for k in a if k in a1}
|
'type guard only support testing variables',
|
||||||
return defined, b | b1, r and r1
|
node.test)
|
||||||
|
t = parse_expr(ctx, sym_table, node.test.left)
|
||||||
|
if not isinstance(t, TypeVariable) or len(t.constraints) < 2:
|
||||||
|
raise CustomError(
|
||||||
|
'type guard only support basic type variables with constraints',
|
||||||
|
node.test)
|
||||||
|
t1, _ = parse_type(ctx, node.test.comparators[0])
|
||||||
|
if t1 not in t.constraints:
|
||||||
|
raise CustomError(
|
||||||
|
f'{t1} is not in constraints of {t}',
|
||||||
|
node.test)
|
||||||
|
t2 = [v for v in t.constraints if v != t1]
|
||||||
|
try:
|
||||||
|
t.constraints = [t1]
|
||||||
|
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
||||||
|
used_sym_table |= b
|
||||||
|
t.constraints = t2
|
||||||
|
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
|
||||||
|
defined = {k: a[k] for k in a if k in a1}
|
||||||
|
return defined, b | b1, r and r1
|
||||||
|
finally:
|
||||||
|
t.constraints = t
|
||||||
|
else:
|
||||||
|
test = parse_expr(ctx, sym_table, node.test)
|
||||||
|
if test != ctx.types['bool']:
|
||||||
|
raise CustomError(f'condition must be bool instead of {test}', node.test)
|
||||||
|
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
||||||
|
used_sym_table |= b
|
||||||
|
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
|
||||||
|
defined = {k: a[k] for k in a if k in a1}
|
||||||
|
return defined, b | b1, r and r1
|
||||||
|
|
||||||
def parse_for_stmt(ctx: Context,
|
def parse_for_stmt(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
|
@ -169,6 +200,10 @@ def parse_return_stmt(ctx: Context,
|
||||||
return {}, {}, True
|
return {}, {}, True
|
||||||
ty = parse_expr(ctx, sym_table, node.value)
|
ty = parse_expr(ctx, sym_table, node.value)
|
||||||
if ty != return_ty:
|
if ty != return_ty:
|
||||||
|
if isinstance(return_ty, TypeVariable):
|
||||||
|
if len(return_ty.constraints) == 1 and \
|
||||||
|
return_ty.constraints[0] == ty:
|
||||||
|
return {}, {}, True
|
||||||
raise CustomError(f'expected returning {return_ty} but got {ty}', node)
|
raise CustomError(f'expected returning {return_ty} but got {ty}', node)
|
||||||
return {}, {}, True
|
return {}, {}, True
|
||||||
|
|
||||||
|
|
|
@ -121,7 +121,7 @@ def parse_type_var(ctx: Context, node):
|
||||||
not isinstance(node.value.args[0].value, str):
|
not isinstance(node.value.args[0].value, str):
|
||||||
raise CustomError('call to TypeVar must at least have a name',
|
raise CustomError('call to TypeVar must at least have a name',
|
||||||
node.value)
|
node.value)
|
||||||
name = node.value.args[0]
|
name = node.targets[0].id
|
||||||
if name in ctx.variables:
|
if name in ctx.variables:
|
||||||
raise CustomError('redefining type variable is not allowed', node)
|
raise CustomError('redefining type variable is not allowed', node)
|
||||||
constraints = []
|
constraints = []
|
||||||
|
@ -137,7 +137,7 @@ def parse_type_var(ctx: Context, node):
|
||||||
if value not in ctx.types:
|
if value not in ctx.types:
|
||||||
raise CustomError(f'unbounded type {value}', node)
|
raise CustomError(f'unbounded type {value}', node)
|
||||||
constraints.append(ctx.types[value])
|
constraints.append(ctx.types[value])
|
||||||
ctx.variables[node.targets[0].id] = TypeVariable(name, constraints)
|
ctx.variables[name] = TypeVariable(node.value.args[0].value, constraints)
|
||||||
|
|
||||||
|
|
||||||
def parse_top_level(ctx: Context, module: ast.Module):
|
def parse_top_level(ctx: Context, module: ast.Module):
|
||||||
|
|
Loading…
Reference in New Issue