Implement Call.

This commit is contained in:
whitequark 2015-06-15 16:55:13 +03:00
parent 7a00a4a47f
commit 8c5e58f83c
4 changed files with 121 additions and 2 deletions

View File

@ -152,6 +152,9 @@ class TFunction(Type):
def __init__(self, args, optargs, ret): def __init__(self, args, optargs, ret):
self.args, self.optargs, self.ret = args, optargs, ret self.args, self.optargs, self.ret = args, optargs, ret
def arity(self):
return len(self.args) + len(self.optargs)
def find(self): def find(self):
return self return self
@ -229,6 +232,9 @@ def is_tuple(typ, elts=None):
else: else:
return isinstance(typ, TTuple) return isinstance(typ, TTuple)
def is_function(typ):
return isinstance(typ.find(), TFunction)
def get_value(typ): def get_value(typ):
typ = typ.find() typ = typ.find()
if isinstance(typ, TVar): if isinstance(typ, TVar):

View File

@ -303,6 +303,15 @@ class ASTTypedRewriter(algorithm.Transformer):
finally: finally:
self.env_stack.pop() self.env_stack.pop()
def visit_Call(self, node):
node = self.generic_visit(node)
node = asttyped.CallT(type=types.TVar(),
func=node.func, args=node.args, keywords=node.keywords,
starargs=node.starargs, kwargs=node.kwargs,
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc)
return node
def visit_Lambda(self, node): def visit_Lambda(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)
@ -337,7 +346,6 @@ class ASTTypedRewriter(algorithm.Transformer):
self.engine.process(diag) self.engine.process(diag)
# expr # expr
visit_Call = visit_unsupported
visit_Dict = visit_unsupported visit_Dict = visit_unsupported
visit_DictComp = visit_unsupported visit_DictComp = visit_unsupported
visit_Ellipsis = visit_unsupported visit_Ellipsis = visit_unsupported
@ -707,6 +715,79 @@ class Inferencer(algorithm.Visitor):
self.generic_visit(node) self.generic_visit(node)
self._unify_collection(element=node.target, collection=node.iter) self._unify_collection(element=node.target, collection=node.iter)
def visit_CallT(self, node):
self.generic_visit(node)
for (sigil_loc, vararg) in ((node.star_loc, node.starargs),
(node.dstar_loc, node.kwargs)):
if vararg:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported", {},
sigil_loc, [vararg.loc])
self.engine.process(diag)
return
if types.is_var(node.func.type):
return # not enough info yet
elif not types.is_function(node.func.type):
diag = diagnostic.Diagnostic("error",
"cannot call this expression of type {type}",
{"type": types.TypePrinter().name(node.func.type)},
node.func.loc, [])
self.engine.process(diag)
return
typ = node.func.type.find()
passed_args = set()
if len(node.args) > typ.arity():
note = diagnostic.Diagnostic("note",
"extraneous argument(s)", {},
node.args[typ.arity()].loc.join(node.args[-1].loc))
diag = diagnostic.Diagnostic("error",
"this function of type {type} accepts at most {num} arguments",
{"type": types.TypePrinter().name(node.func.type),
"num": typ.arity()},
node.func.loc, [], [note])
self.engine.process(diag)
return
for actualarg, (formalname, formaltyp) in \
zip(node.args, list(typ.args.items()) + list(typ.optargs.items())):
self._unify(actualarg.type, formaltyp,
actualarg.loc, None)
passed_args.add(formalname)
for keyword in node.keywords:
if keyword.arg in passed_args:
diag = diagnostic.Diagnostic("error",
"the argument '{name}' is already passed",
{"name": keyword.arg},
keyword.arg_loc)
self.engine.process(diag)
return
if keyword.arg in typ.args:
self._unify(keyword.value.type, typ.args[keyword.arg],
keyword.value.loc, None)
elif keyword.arg in typ.optargs:
self._unify(keyword.value.type, typ.optargs[keyword.arg],
keyword.value.loc, None)
passed_args.add(keyword.arg)
for formalname in typ.args:
if formalname not in passed_args:
note = diagnostic.Diagnostic("note",
"the called function is of type {type}",
{"type": types.TypePrinter().name(node.func.type)},
node.func.loc)
diag = diagnostic.Diagnostic("error",
"mandatory argument '{name}' is not passed",
{"name": formalname},
node.begin_loc.join(node.end_loc), [], [note])
self.engine.process(diag)
return
def visit_LambdaT(self, node): def visit_LambdaT(self, node):
self.generic_visit(node) self.generic_visit(node)
signature_type = self._type_from_arguments(node.args, node.body.type) signature_type = self._type_from_arguments(node.args, node.body.type)
@ -818,7 +899,7 @@ class Inferencer(algorithm.Visitor):
return OrderedDict(args) return OrderedDict(args)
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]), return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
extract_args(node.args[len(node.defaults):]), extract_args(node.args[len(node.args) - len(node.defaults):]),
ret) ret)
def visit_arguments(self, node): def visit_arguments(self, node):

View File

@ -0,0 +1,20 @@
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: ${LINE:+1}: error: cannot call this expression of type int
(1)()
def f(x, y, z=1):
pass
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
f(*[])
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
f(**[])
# CHECK-L: ${LINE:+1}: error: the argument 'x' is already passed
f(1, x=1)
# CHECK-L: ${LINE:+1}: error: mandatory argument 'x' is not passed
f()

View File

@ -0,0 +1,12 @@
# RUN: %python -m artiq.py2llvm.typing %s >%t
def _gcd(a, b):
if a < 0:
a = -a
while a:
c = a
a = b % a
b = c
return b
_gcd(10, 25)