added simple lifetime check
This commit is contained in:
parent
3457e594ec
commit
84d09f1fd1
|
@ -1,30 +1,65 @@
|
|||
I = TypeVar('I', int32, float, Vec)
|
||||
T = TypeVar('T')
|
||||
|
||||
class Vec:
|
||||
v: list[int32]
|
||||
def __init__(self, v: list[int32]):
|
||||
self.v = v
|
||||
class Foo:
|
||||
a: list[int32]
|
||||
b: list[int32]
|
||||
|
||||
def __add__(self, other: I) -> Vec:
|
||||
if type(other) == int32:
|
||||
return Vec([v + other for v in self.v])
|
||||
elif type(other) == float:
|
||||
return Vec([v + int32(other) for v in self.v])
|
||||
else:
|
||||
return Vec([self.v[i] + other.v[i] for i in range(len(self.v))])
|
||||
|
||||
def get(self, index: int32) -> int32:
|
||||
return self.v.head()
|
||||
|
||||
|
||||
T = TypeVar('T', int32, list[int32])
|
||||
|
||||
def add(a: int32, b: T) -> int32:
|
||||
if type(b) == int32:
|
||||
return a + b
|
||||
else:
|
||||
for x in b:
|
||||
a = add(a, x)
|
||||
def choose(t: bool, a: T, b: T) -> T:
|
||||
if t:
|
||||
return a
|
||||
else:
|
||||
return b
|
||||
|
||||
def set_list(ls: list[T], a: T):
|
||||
# this should fail
|
||||
l2 = ls
|
||||
l2[-1] = a
|
||||
|
||||
def get_foo(a: Foo) -> list[int32]:
|
||||
return a.a
|
||||
|
||||
def set_foo(a: Foo, b: Foo):
|
||||
a.a[0] = b.a[0]
|
||||
if True:
|
||||
c = b
|
||||
# this should fail
|
||||
c.a = a.a
|
||||
|
||||
def set_foo2(a: Foo, b: Foo):
|
||||
a.a[0] = b.a[0]
|
||||
if True:
|
||||
c = [Foo()]
|
||||
c[0] = a
|
||||
# this should fail
|
||||
c[0].a = b.a
|
||||
|
||||
def set_foo3(a: Foo, b: Foo):
|
||||
a.a[0] = b.a[0]
|
||||
if True:
|
||||
c = [Foo()]
|
||||
c[0] = a
|
||||
# this should fail
|
||||
c[0].a = get_foo(b)
|
||||
|
||||
def set_foo4(a: Foo, b: Foo):
|
||||
a.a[0] = b.a[0]
|
||||
if True:
|
||||
c = [Foo()]
|
||||
d = c
|
||||
d[0] = a
|
||||
# this should fail
|
||||
c[0].a = get_foo(b)
|
||||
|
||||
def set_foo5(a: Foo, b: Foo):
|
||||
a.a[0] = b.a[0]
|
||||
if True:
|
||||
c = [Foo()]
|
||||
d = c
|
||||
e = d
|
||||
f = e
|
||||
f[0] = a
|
||||
# this should fail
|
||||
c[0].a = get_foo(b)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
import ast
|
||||
from type_def import PrimitiveType
|
||||
from parse_expr import parse_expr
|
||||
|
||||
class Lifetime:
|
||||
low: int
|
||||
original: int
|
||||
|
||||
def __init__(self, scope):
|
||||
self.low = self.original = scope
|
||||
self.parent = None
|
||||
|
||||
def fold(self):
|
||||
while self.parent is not None:
|
||||
self.low = self.parent.low
|
||||
self.original = self.parent.original
|
||||
self.parent = self.parent.parent
|
||||
return self
|
||||
|
||||
def ok(self, other):
|
||||
self.fold()
|
||||
if other == None:
|
||||
return False
|
||||
other.fold()
|
||||
return self.low >= other.original and \
|
||||
(other.original != self.low or self.low != 1)
|
||||
|
||||
def __str__(self):
|
||||
self.fold()
|
||||
return f'({self.low}, {self.original})'
|
||||
|
||||
|
||||
def assign_expr(
|
||||
scope: int,
|
||||
sym_table: dict[str, Lifetime],
|
||||
expr: ast.expr):
|
||||
if isinstance(expr, ast.Expression):
|
||||
body = expr.body
|
||||
else:
|
||||
body = expr
|
||||
if isinstance(body.type, PrimitiveType):
|
||||
body.lifetime = None
|
||||
elif isinstance(body, ast.Attribute):
|
||||
body.lifetime = assign_expr(scope, sym_table, body.value)
|
||||
elif isinstance(body, ast.Subscript):
|
||||
body.lifetime = assign_expr(scope, sym_table, body.value)
|
||||
elif isinstance(body, ast.Name):
|
||||
if body.id in sym_table:
|
||||
body.lifetime = sym_table[body.id]
|
||||
else:
|
||||
body.lifetime = Lifetime(scope)
|
||||
sym_table[body.id] = body.lifetime
|
||||
else:
|
||||
body.lifetime = Lifetime(scope)
|
||||
return body.lifetime
|
||||
|
||||
|
||||
def assign_stmt(
|
||||
scope: int,
|
||||
sym_table: dict[str, Lifetime],
|
||||
nodes):
|
||||
for node in nodes:
|
||||
if isinstance(node, ast.Assign):
|
||||
b = assign_expr(scope, sym_table, node.value)
|
||||
for target in node.targets:
|
||||
a = assign_expr(scope, sym_table, target)
|
||||
if a == None and b == None:
|
||||
continue
|
||||
if not a.ok(b):
|
||||
print(ast.unparse(node))
|
||||
print(f'{a} <- {b}')
|
||||
assert False
|
||||
a.low = min(a.low, b.low)
|
||||
a.original = max(a.original, b.original)
|
||||
b.parent = a
|
||||
elif isinstance(node, ast.If) or isinstance(node, ast.While):
|
||||
assign_stmt(scope + 1, sym_table, node.body)
|
||||
assign_stmt(scope + 1, sym_table, node.orelse)
|
||||
elif isinstance(node, ast.Return):
|
||||
a = assign_expr(scope, sym_table, node.value)
|
||||
if a != None and a.fold().original > 1:
|
||||
print(ast.unparse(node))
|
||||
print(a)
|
||||
assert False
|
||||
|
||||
|
|
@ -7,6 +7,7 @@ from parse_stmt import parse_stmts
|
|||
from primitives import simplest_ctx
|
||||
from top_level import parse_top_level
|
||||
from inheritance import class_fixup
|
||||
from lifetime import Lifetime, assign_stmt
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print('please pass the python script name as argument')
|
||||
|
@ -40,9 +41,20 @@ try:
|
|||
if isinstance(ty, SelfType):
|
||||
ty = ctx.types[c]
|
||||
sym_table[n.arg] = ty.subst(subst)
|
||||
_, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body)
|
||||
if result is not None and not returned:
|
||||
raise CustomError('Function may have no return value', fn)
|
||||
try:
|
||||
print()
|
||||
print('checking:')
|
||||
print(ast.unparse(fn))
|
||||
print('typecheck...')
|
||||
_, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body)
|
||||
if result is not None and not returned:
|
||||
raise CustomError('Function may have no return value', fn)
|
||||
print('lifetime check...')
|
||||
sym_table = {k: Lifetime(1) for k in sym_table}
|
||||
assign_stmt(2, sym_table, fn.body)
|
||||
print('OK!')
|
||||
except AssertionError:
|
||||
pass
|
||||
except CustomError as e:
|
||||
print('Error while type checking:')
|
||||
print(e.msg)
|
||||
|
|
|
@ -17,33 +17,35 @@ def parse_expr(ctx: Context,
|
|||
else:
|
||||
body = expr
|
||||
if isinstance(body, ast.Constant):
|
||||
return parse_constant(ctx, sym_table, body)
|
||||
if isinstance(body, ast.UnaryOp):
|
||||
return parse_unary_ops(ctx, sym_table, body)
|
||||
if isinstance(body, ast.BinOp):
|
||||
return parse_bin_ops(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Name):
|
||||
return parse_name(ctx, sym_table, body)
|
||||
if isinstance(body, ast.List):
|
||||
return parse_list(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Tuple):
|
||||
return parse_tuple(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Attribute):
|
||||
return parse_attribute(ctx, sym_table, body)
|
||||
if isinstance(body, ast.BoolOp):
|
||||
return parse_bool_ops(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Compare):
|
||||
return parse_compare(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Call):
|
||||
return parse_call(ctx, sym_table, body)
|
||||
if isinstance(body, ast.Subscript):
|
||||
return parse_subscript(ctx, sym_table, body)
|
||||
if isinstance(body, ast.IfExp):
|
||||
return parse_if_expr(ctx, sym_table, body)
|
||||
if isinstance(body, ast.ListComp):
|
||||
return parse_list_comprehension(ctx, sym_table, body)
|
||||
raise CustomError(f'{body} is not yet supported', body)
|
||||
|
||||
result = parse_constant(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.UnaryOp):
|
||||
result = parse_unary_ops(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.BinOp):
|
||||
result = parse_bin_ops(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Name):
|
||||
result = parse_name(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.List):
|
||||
result = parse_list(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Tuple):
|
||||
result = parse_tuple(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Attribute):
|
||||
result = parse_attribute(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.BoolOp):
|
||||
result = parse_bool_ops(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Compare):
|
||||
result = parse_compare(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Call):
|
||||
result = parse_call(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.Subscript):
|
||||
result = parse_subscript(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.IfExp):
|
||||
result = parse_if_expr(ctx, sym_table, body)
|
||||
elif isinstance(body, ast.ListComp):
|
||||
result = parse_list_comprehension(ctx, sym_table, body)
|
||||
else:
|
||||
raise CustomError(f'{body} is not yet supported', body)
|
||||
body.type = result
|
||||
return result
|
||||
|
||||
def get_unary_op(op):
|
||||
if isinstance(op, ast.UAdd):
|
||||
|
@ -238,6 +240,7 @@ def parse_simple_binding(name, ty):
|
|||
if isinstance(name, ast.Name):
|
||||
if name.id == '_':
|
||||
return {}
|
||||
name.type = ty
|
||||
return {name.id: ty}
|
||||
elif isinstance(name, ast.Tuple):
|
||||
if not isinstance(ty, TupleType):
|
||||
|
|
|
@ -50,15 +50,18 @@ def get_target_type(ctx: Context,
|
|||
i = parse_expr(ctx, sym_table, target.slice)
|
||||
if i != ctx.types['int32']:
|
||||
raise CustomError(f'index must be int32', target.slice)
|
||||
target.type = t.params[0]
|
||||
return t.params[0]
|
||||
elif isinstance(target, ast.Attribute):
|
||||
t = get_target_type(ctx, sym_table, used_sym_table, target.value)
|
||||
if target.attr not in t.fields:
|
||||
raise CustomError(f'{t} has no field {target.attr}', target)
|
||||
target.type = t.fields[target.attr]
|
||||
return t.fields[target.attr]
|
||||
elif isinstance(target, ast.Name):
|
||||
if target.id not in sym_table:
|
||||
raise CustomError(f'unbounded {target.id}', target)
|
||||
target.type = sym_table[target.id]
|
||||
return sym_table[target.id]
|
||||
else:
|
||||
raise CustomError(f'assignment to {target} is not supported', target)
|
||||
|
@ -76,6 +79,7 @@ def parse_stmt_binding(ctx: Context,
|
|||
f'but is now {ty}', target)
|
||||
if target.id == '_':
|
||||
return {}
|
||||
target.type = ty
|
||||
return {target.id: ty}
|
||||
elif isinstance(target, ast.Tuple):
|
||||
if not isinstance(ty, TupleType):
|
||||
|
|
Loading…
Reference in New Issue