implementing function call check
This commit is contained in:
parent
fe6d9cc446
commit
ce031999e9
|
@ -0,0 +1,9 @@
|
||||||
|
class CustomError(Exception):
|
||||||
|
def __init__(self, msg):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
def stringify_subst(subst):
|
||||||
|
if isinstance(subst, str):
|
||||||
|
return subst
|
||||||
|
elements = [f"{key}: {str(value)}" for key, value in subst.items()]
|
||||||
|
return "{" + ', '.join(elements) + "}"
|
|
@ -1,12 +1,15 @@
|
||||||
from type_def import *
|
from type_def import *
|
||||||
|
from helper import *
|
||||||
|
|
||||||
def find_subst(ctx: dict[str, Type],
|
def find_subst(ctx: dict[str, Type],
|
||||||
sub: dict[str, Type],
|
sub: dict[str, Type],
|
||||||
a: Type,
|
a: Type,
|
||||||
b: Type):
|
b: Type):
|
||||||
"""
|
"""
|
||||||
Find substitution s such that ctx(a) = s(sub(ctx(b)))
|
Find substitution s such that ctx(a) = s(sub(b)).
|
||||||
return error message if type mismatch
|
Note that variables in a and b are considered independent.
|
||||||
|
return s.sub if such s exists (. means function composition).
|
||||||
|
return error message if type mismatch.
|
||||||
"""
|
"""
|
||||||
# is error
|
# is error
|
||||||
if isinstance(sub, str):
|
if isinstance(sub, str):
|
||||||
|
@ -16,9 +19,7 @@ def find_subst(ctx: dict[str, Type],
|
||||||
a = ctx[a.name]
|
a = ctx[a.name]
|
||||||
|
|
||||||
if isinstance(b, TypeVariable):
|
if isinstance(b, TypeVariable):
|
||||||
if b.name in ctx:
|
if b.name in sub:
|
||||||
b = ctx[b.name]
|
|
||||||
elif b.name in sub:
|
|
||||||
b = sub[b.name]
|
b = sub[b.name]
|
||||||
else:
|
else:
|
||||||
if len(b.constraints) > 0:
|
if len(b.constraints) > 0:
|
||||||
|
@ -32,10 +33,6 @@ def find_subst(ctx: dict[str, Type],
|
||||||
else:
|
else:
|
||||||
if a not in b.constraints:
|
if a not in b.constraints:
|
||||||
return f"{b} cannot take value of {a}"
|
return f"{b} cannot take value of {a}"
|
||||||
if a == b:
|
|
||||||
return sub
|
|
||||||
if b in a.get_vars():
|
|
||||||
return "Recursive type is not supported"
|
|
||||||
sub[b.name] = a
|
sub[b.name] = a
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
|
@ -47,8 +44,12 @@ def find_subst(ctx: dict[str, Type],
|
||||||
|
|
||||||
if isinstance(a, BotType):
|
if isinstance(a, BotType):
|
||||||
return sub
|
return sub
|
||||||
|
# TODO: virtual type is not handled currently
|
||||||
|
# we need to access the class dictionary to handle this
|
||||||
if type(a) == type(b):
|
if type(a) == type(b):
|
||||||
if isinstance(a, ParametricType):
|
if isinstance(a, ParametricType):
|
||||||
|
if len(a.params) != len(b.params):
|
||||||
|
return f"{a} != {b}"
|
||||||
old = sub
|
old = sub
|
||||||
for x, y in zip(a.params, b.params):
|
for x, y in zip(a.params, b.params):
|
||||||
old = find_subst(ctx, old, x, y)
|
old = find_subst(ctx, old, x, y)
|
||||||
|
@ -62,4 +63,86 @@ def find_subst(ctx: dict[str, Type],
|
||||||
raise Exception()
|
raise Exception()
|
||||||
return f"{a} != {b}"
|
return f"{a} != {b}"
|
||||||
|
|
||||||
|
def resolve_call(obj,
|
||||||
|
fn: str,
|
||||||
|
args: list[Type],
|
||||||
|
assumptions: dict[str, Type],
|
||||||
|
ctx: Context) -> tuple[Type]:
|
||||||
|
# TODO: we may want to return the substitution, for monomorphization...
|
||||||
|
f_args = None
|
||||||
|
f_result = None
|
||||||
|
if obj is not None:
|
||||||
|
obj = obj.subst(assumptions)
|
||||||
|
if obj is None:
|
||||||
|
if fn in ctx.functions:
|
||||||
|
f = ctx.functions[fn]
|
||||||
|
f_args, f_result = TupleType(f[0]), f[1]
|
||||||
|
elif fn in ctx.types:
|
||||||
|
c = ctx.types[fn]
|
||||||
|
if '__init__' in c.methods:
|
||||||
|
f = c.methods['__init__']
|
||||||
|
if not isinstance(f[0][0], SelfType) or f[1] is not None:
|
||||||
|
raise CustomError(
|
||||||
|
f'__init__ of {c} should accept self and return None'
|
||||||
|
)
|
||||||
|
f_args, f_result = TupleType(f[0][1:]), c
|
||||||
|
else:
|
||||||
|
f_args, f_result = TupleType([]), c
|
||||||
|
else:
|
||||||
|
raise CustomError(f"No such function {fn}")
|
||||||
|
elif isinstance(obj, PrimitiveType) or isinstance(obj, ClassType):
|
||||||
|
if fn in obj.methods:
|
||||||
|
f = obj.methods[fn]
|
||||||
|
if len(f[0]) == 0 or (not isinstance(f[0][0], SelfType) and \
|
||||||
|
f[0][0] != obj):
|
||||||
|
raise CustomError('{f} is not a method of {obj}')
|
||||||
|
f_args, f_result = TupleType(f[0][1:]), f[1]
|
||||||
|
else:
|
||||||
|
raise CustomError(f"No such method {fn} in {c}")
|
||||||
|
elif isinstance(obj, VirtualClassType):
|
||||||
|
# may need to emit special annotation that this is a virtual method
|
||||||
|
# call?
|
||||||
|
if fn in obj.base.methods:
|
||||||
|
f = obj.base.methods[fn]
|
||||||
|
if len(f[0]) == 0 or not isinstance(f[0][0], SelfType):
|
||||||
|
raise CustomError('{f} is not a method of {obj}')
|
||||||
|
f_args, f_result = TupleType(f[0][1:]), f[1]
|
||||||
|
else:
|
||||||
|
raise CustomError(f"No such method {fn} in {c}")
|
||||||
|
elif isinstance(obj, TypeVariable):
|
||||||
|
# if not constrained, error. otherwise, try all values, and only allow
|
||||||
|
# if the results are the same or if they are the same modulo the
|
||||||
|
# substitution.
|
||||||
|
# expensive operation, but cache should be applicable
|
||||||
|
# in order to cache this, our cache must be able to compare equality
|
||||||
|
# modulo variable naming... probably not easy either
|
||||||
|
if len(obj.constraints) == 0:
|
||||||
|
raise CustomError("no methods for unconstrained object")
|
||||||
|
results = [resolve_call(obj, fn, args, assumptions | {obj.name: v}, ctx)
|
||||||
|
for v in obj.assumptions]
|
||||||
|
for v in results[1:]:
|
||||||
|
if v != results[0]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# same result
|
||||||
|
return results[0]
|
||||||
|
results = [v.inv_subst([a, obj])
|
||||||
|
for v, a in zip(results, obj.assumptions)]
|
||||||
|
for v in results[1:]:
|
||||||
|
if v != results[0]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# same result
|
||||||
|
return results[0]
|
||||||
|
raise CustomError("Divergent type after constraints substitution")
|
||||||
|
|
||||||
|
a = TupleType(args)
|
||||||
|
subst = find_subst(assumptions, {}, a, f_args)
|
||||||
|
if isinstance(subst, str):
|
||||||
|
raise CustomError(f"type check failed: {subst}")
|
||||||
|
result = f_result.subst(subst)
|
||||||
|
if isinstance(result, SelfType):
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
from type_def import *
|
||||||
|
from inference import *
|
||||||
|
from helper import *
|
||||||
|
|
||||||
|
types = {
|
||||||
|
'int32': PrimitiveType('int32'),
|
||||||
|
'int64': PrimitiveType('int64'),
|
||||||
|
'str': PrimitiveType('str'),
|
||||||
|
}
|
||||||
|
|
||||||
|
i32 = types['int32']
|
||||||
|
i64 = types['int64']
|
||||||
|
s = types['str']
|
||||||
|
|
||||||
|
|
||||||
|
variables = {
|
||||||
|
'X': TypeVariable('X', []),
|
||||||
|
'Y': TypeVariable('Y', []),
|
||||||
|
'I': TypeVariable('I', [i32, i64]),
|
||||||
|
'A': TypeVariable('A', [i32, i64, s]),
|
||||||
|
}
|
||||||
|
|
||||||
|
X = variables['X']
|
||||||
|
Y = variables['Y']
|
||||||
|
I = variables['I']
|
||||||
|
A = variables['A']
|
||||||
|
|
||||||
|
i32.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
|
i32.methods['__add__'] = ([SelfType(), i32], i32, set())
|
||||||
|
i32.methods['__sub__'] = ([SelfType(), i32], i32, set())
|
||||||
|
|
||||||
|
i64.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
|
i64.methods['__add__'] = ([SelfType(), i64], i64, set())
|
||||||
|
i64.methods['__sub__'] = ([SelfType(), i64], i64, set())
|
||||||
|
|
||||||
|
ctx = Context(variables, types)
|
||||||
|
|
||||||
|
|
||||||
|
def test_call(obj, fn, args, assumptions = {}):
|
||||||
|
args_str = ', '.join([str(v) for v in args])
|
||||||
|
obj_str = '' if obj is None else str(obj) + '.'
|
||||||
|
print(f'Testing {obj_str}{fn}({args_str}) w.r.t. {stringify_subst(assumptions)}')
|
||||||
|
try:
|
||||||
|
result = resolve_call(obj, fn, args, assumptions, ctx)
|
||||||
|
print(result)
|
||||||
|
except CustomError as err:
|
||||||
|
print(f'error: {err.msg}')
|
||||||
|
|
||||||
|
test_call(None, 'int32', [])
|
||||||
|
test_call(None, 'int32', [i32])
|
||||||
|
test_call(None, 'int32', [i64])
|
||||||
|
test_call(None, 'int32', [I])
|
||||||
|
test_call(None, 'int32', [A])
|
||||||
|
test_call(None, 'int32', [i32, i64])
|
||||||
|
test_call(i32, '__add__', [])
|
||||||
|
test_call(i32, '__add__', [i32])
|
||||||
|
test_call(i32, '__add__', [i64])
|
||||||
|
test_call(i32, '__add__', [i32, i32])
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from type_def import *
|
from type_def import *
|
||||||
from inference import *
|
from inference import *
|
||||||
|
from helper import *
|
||||||
|
|
||||||
types = {
|
types = {
|
||||||
'int32': PrimitiveType('int32'),
|
'int32': PrimitiveType('int32'),
|
||||||
|
@ -49,6 +50,7 @@ try_case(A, I, {})
|
||||||
try_case(X, I, {})
|
try_case(X, I, {})
|
||||||
try_case(ListType(i32), TupleType([i32]), {})
|
try_case(ListType(i32), TupleType([i32]), {})
|
||||||
try_case(TupleType([i32]), ListType(i32), {})
|
try_case(TupleType([i32]), ListType(i32), {})
|
||||||
|
try_case(TupleType([i32, i32]), TupleType([i32]), {})
|
||||||
try_case(ListType(i32), ListType(i32), {})
|
try_case(ListType(i32), ListType(i32), {})
|
||||||
try_case(TupleType([X, X]), TupleType([X, Y]), {})
|
try_case(TupleType([X, X]), TupleType([X, Y]), {})
|
||||||
try_case(TupleType([X, X]), TupleType([Y, Y]), {})
|
try_case(TupleType([X, X]), TupleType([Y, Y]), {})
|
||||||
|
@ -57,4 +59,3 @@ try_case(TupleType([X, X]), TupleType([X, X]), {})
|
||||||
try_case(TupleType([X, Y]), X, {})
|
try_case(TupleType([X, Y]), X, {})
|
||||||
try_case(TupleType([i32, Y]), X, {})
|
try_case(TupleType([i32, Y]), X, {})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ class A:
|
||||||
|
|
||||||
class B(A):
|
class B(A):
|
||||||
a: str
|
a: str
|
||||||
def bar(a: list[list[virtual[A]]]) -> A:
|
def bar(self, a: list[list[virtual[A]]]) -> self:
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ variables = {'X': TypeVariable('X', []), 'Y': TypeVariable('Y', [])}
|
||||||
types = {'int': PrimitiveType('int'), 'str': PrimitiveType('str')}
|
types = {'int': PrimitiveType('int'), 'str': PrimitiveType('str')}
|
||||||
ctx = Context(variables, types)
|
ctx = Context(variables, types)
|
||||||
|
|
||||||
ctx, functions, _ = parse_top_level(ctx, ast.parse(test))
|
ctx, _ = parse_top_level(ctx, ast.parse(test))
|
||||||
|
|
||||||
for name, t in ctx.types.items():
|
for name, t in ctx.types.items():
|
||||||
if isinstance(t, ClassType):
|
if isinstance(t, ClassType):
|
||||||
|
|
|
@ -1,17 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
from type_def import *
|
from type_def import *
|
||||||
|
from helper import *
|
||||||
class CustomError(Exception):
|
|
||||||
def __init__(self, msg):
|
|
||||||
self.msg = msg
|
|
||||||
|
|
||||||
class Context:
|
|
||||||
variables: dict[str, TypeVariable]
|
|
||||||
types: dict[Type]
|
|
||||||
|
|
||||||
def __init__(self, variables, types):
|
|
||||||
self.variables = variables
|
|
||||||
self.types = types
|
|
||||||
|
|
||||||
|
|
||||||
def parse_type(ctx: Context, ty):
|
def parse_type(ctx: Context, ty):
|
||||||
|
@ -65,9 +54,13 @@ def parse_function(ctx: Context, base, fn: ast.FunctionDef):
|
||||||
ty, v = parse_type(ctx, arg.annotation)
|
ty, v = parse_type(ctx, arg.annotation)
|
||||||
var |= v
|
var |= v
|
||||||
if name == 'self' and ty is None and base is not None:
|
if name == 'self' and ty is None and base is not None:
|
||||||
ty = base
|
ty = SelfType()
|
||||||
args.append(ty)
|
args.append(ty)
|
||||||
result, v = parse_type(ctx, fn.returns)
|
if isinstance(fn.returns, ast.Name) and fn.returns.id == 'self'\
|
||||||
|
and base is not None:
|
||||||
|
result, v = SelfType(), set()
|
||||||
|
else:
|
||||||
|
result, v = parse_type(ctx, fn.returns)
|
||||||
if len(v - var) > 0:
|
if len(v - var) > 0:
|
||||||
raise CustomError(f"Unbounded variable in return type of {fn.name}")
|
raise CustomError(f"Unbounded variable in return type of {fn.name}")
|
||||||
return args, result, var
|
return args, result, var
|
||||||
|
@ -126,7 +119,6 @@ def parse_top_level(ctx: Context, module: ast.Module):
|
||||||
to_be_processed.append(element)
|
to_be_processed.append(element)
|
||||||
|
|
||||||
# second pass, obtain all function types
|
# second pass, obtain all function types
|
||||||
functions = {}
|
|
||||||
function_stmts = []
|
function_stmts = []
|
||||||
for element in to_be_processed:
|
for element in to_be_processed:
|
||||||
if isinstance(element, ast.ClassDef):
|
if isinstance(element, ast.ClassDef):
|
||||||
|
@ -135,10 +127,12 @@ def parse_top_level(ctx: Context, module: ast.Module):
|
||||||
name = element.name
|
name = element.name
|
||||||
if name in functions:
|
if name in functions:
|
||||||
raise CustomError(f"Duplicated function name {name}")
|
raise CustomError(f"Duplicated function name {name}")
|
||||||
|
if name in ctx.types:
|
||||||
|
raise CustomError(f"Function name {name} clashed with type name")
|
||||||
args, result, var = parse_function(ctx, None, element)
|
args, result, var = parse_function(ctx, None, element)
|
||||||
functions[name] = (args, result, var)
|
ctx.functions[name] = (args, result, var)
|
||||||
function_stmts += element
|
function_stmts += element
|
||||||
|
|
||||||
return ctx, functions, function_stmts
|
return ctx, function_stmts
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,28 @@
|
||||||
|
import copy
|
||||||
|
|
||||||
class Type:
|
class Type:
|
||||||
|
methods: dict[str, tuple[list['Type'], 'Type', set[str]]]
|
||||||
|
fields: dict[str, 'Type']
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.methods = {}
|
||||||
|
self.fields = {}
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_vars(self):
|
def get_vars(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def subst(self, subst: dict[str, 'Type']):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def inv_subst(self, subst: list[tuple['Type', 'TypeVariable']]):
|
||||||
|
for t, tv in subst:
|
||||||
|
if self == t:
|
||||||
|
return tv
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class BotType:
|
class BotType:
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
@ -15,6 +33,7 @@ class PrimitiveType(Type):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -32,6 +51,7 @@ class TypeVariable(Type):
|
||||||
constraints: list[Type]
|
constraints: list[Type]
|
||||||
|
|
||||||
def __init__(self, name: str, constraints: list[Type]):
|
def __init__(self, name: str, constraints: list[Type]):
|
||||||
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.constraints = constraints
|
self.constraints = constraints
|
||||||
|
|
||||||
|
@ -44,6 +64,11 @@ class TypeVariable(Type):
|
||||||
def get_vars(self):
|
def get_vars(self):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
|
def subst(self, subst: dict[str, Type]):
|
||||||
|
if self.name in subst:
|
||||||
|
return subst[self.name]
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ClassType(Type):
|
class ClassType(Type):
|
||||||
name: str
|
name: str
|
||||||
|
@ -52,10 +77,9 @@ class ClassType(Type):
|
||||||
fields: dict[str, Type]
|
fields: dict[str, Type]
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.parents = []
|
self.parents = []
|
||||||
self.methods = {}
|
|
||||||
self.fields = {}
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
@ -63,10 +87,17 @@ class ClassType(Type):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, ClassType) and self.name == other.name
|
return isinstance(other, ClassType) and self.name == other.name
|
||||||
|
|
||||||
|
|
||||||
|
class SelfType(Type):
|
||||||
|
def __str__(self):
|
||||||
|
return 'self'
|
||||||
|
|
||||||
|
|
||||||
class VirtualClassType(Type):
|
class VirtualClassType(Type):
|
||||||
base: ClassType
|
base: ClassType
|
||||||
|
|
||||||
def __init__(self, base: ClassType):
|
def __init__(self, base: ClassType):
|
||||||
|
super().__init__()
|
||||||
self.base = base
|
self.base = base
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -79,6 +110,7 @@ class ParametricType(Type):
|
||||||
params: list[Type]
|
params: list[Type]
|
||||||
|
|
||||||
def __init__(self, params: list[Type]):
|
def __init__(self, params: list[Type]):
|
||||||
|
super().__init__()
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
@ -98,6 +130,19 @@ class ParametricType(Type):
|
||||||
result.append(v)
|
result.append(v)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def subst(self, subst: dict[str, Type]):
|
||||||
|
s = copy.copy(self)
|
||||||
|
s.params = [v.subst(subst) for v in self.params]
|
||||||
|
return s
|
||||||
|
|
||||||
|
def inv_subst(self, subst: list[tuple['Type', 'TypeVariable']]):
|
||||||
|
for t, tv in subst:
|
||||||
|
if self == t:
|
||||||
|
return tv
|
||||||
|
s = copy.copy(self)
|
||||||
|
s.params = [v.inv_subst(subst) for v in self.params]
|
||||||
|
return s
|
||||||
|
|
||||||
class ListType(ParametricType):
|
class ListType(ParametricType):
|
||||||
def __init__(self, param: Type):
|
def __init__(self, param: Type):
|
||||||
super().__init__([param])
|
super().__init__([param])
|
||||||
|
@ -114,3 +159,13 @@ class TupleType(ParametricType):
|
||||||
return f"tuple[{', '.join([str(v) for v in self.params])}]"
|
return f"tuple[{', '.join([str(v) for v in self.params])}]"
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
variables: dict[str, TypeVariable]
|
||||||
|
types: dict[str, Type]
|
||||||
|
functions: dict[str, tuple[list[Type], Type, set[str]]]
|
||||||
|
|
||||||
|
def __init__(self, variables, types):
|
||||||
|
self.variables = variables
|
||||||
|
self.types = types
|
||||||
|
self.functions = {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue