forked from M-Labs/artiq
Implement Call.
This commit is contained in:
parent
7a00a4a47f
commit
8c5e58f83c
|
@ -152,6 +152,9 @@ class TFunction(Type):
|
||||||
def __init__(self, args, optargs, ret):
|
def __init__(self, args, optargs, ret):
|
||||||
self.args, self.optargs, self.ret = args, optargs, ret
|
self.args, self.optargs, self.ret = args, optargs, ret
|
||||||
|
|
||||||
|
def arity(self):
|
||||||
|
return len(self.args) + len(self.optargs)
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -229,6 +232,9 @@ def is_tuple(typ, elts=None):
|
||||||
else:
|
else:
|
||||||
return isinstance(typ, TTuple)
|
return isinstance(typ, TTuple)
|
||||||
|
|
||||||
|
def is_function(typ):
|
||||||
|
return isinstance(typ.find(), TFunction)
|
||||||
|
|
||||||
def get_value(typ):
|
def get_value(typ):
|
||||||
typ = typ.find()
|
typ = typ.find()
|
||||||
if isinstance(typ, TVar):
|
if isinstance(typ, TVar):
|
||||||
|
|
|
@ -303,6 +303,15 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
finally:
|
finally:
|
||||||
self.env_stack.pop()
|
self.env_stack.pop()
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
node = asttyped.CallT(type=types.TVar(),
|
||||||
|
func=node.func, args=node.args, keywords=node.keywords,
|
||||||
|
starargs=node.starargs, kwargs=node.kwargs,
|
||||||
|
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
|
||||||
|
begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc)
|
||||||
|
return node
|
||||||
|
|
||||||
def visit_Lambda(self, node):
|
def visit_Lambda(self, node):
|
||||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
extractor.visit(node)
|
extractor.visit(node)
|
||||||
|
@ -337,7 +346,6 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
# expr
|
# expr
|
||||||
visit_Call = visit_unsupported
|
|
||||||
visit_Dict = visit_unsupported
|
visit_Dict = visit_unsupported
|
||||||
visit_DictComp = visit_unsupported
|
visit_DictComp = visit_unsupported
|
||||||
visit_Ellipsis = visit_unsupported
|
visit_Ellipsis = visit_unsupported
|
||||||
|
@ -707,6 +715,79 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self._unify_collection(element=node.target, collection=node.iter)
|
self._unify_collection(element=node.target, collection=node.iter)
|
||||||
|
|
||||||
|
def visit_CallT(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
for (sigil_loc, vararg) in ((node.star_loc, node.starargs),
|
||||||
|
(node.dstar_loc, node.kwargs)):
|
||||||
|
if vararg:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"variadic arguments are not supported", {},
|
||||||
|
sigil_loc, [vararg.loc])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
if types.is_var(node.func.type):
|
||||||
|
return # not enough info yet
|
||||||
|
elif not types.is_function(node.func.type):
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"cannot call this expression of type {type}",
|
||||||
|
{"type": types.TypePrinter().name(node.func.type)},
|
||||||
|
node.func.loc, [])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
typ = node.func.type.find()
|
||||||
|
passed_args = set()
|
||||||
|
|
||||||
|
if len(node.args) > typ.arity():
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"extraneous argument(s)", {},
|
||||||
|
node.args[typ.arity()].loc.join(node.args[-1].loc))
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"this function of type {type} accepts at most {num} arguments",
|
||||||
|
{"type": types.TypePrinter().name(node.func.type),
|
||||||
|
"num": typ.arity()},
|
||||||
|
node.func.loc, [], [note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
for actualarg, (formalname, formaltyp) in \
|
||||||
|
zip(node.args, list(typ.args.items()) + list(typ.optargs.items())):
|
||||||
|
self._unify(actualarg.type, formaltyp,
|
||||||
|
actualarg.loc, None)
|
||||||
|
passed_args.add(formalname)
|
||||||
|
|
||||||
|
for keyword in node.keywords:
|
||||||
|
if keyword.arg in passed_args:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"the argument '{name}' is already passed",
|
||||||
|
{"name": keyword.arg},
|
||||||
|
keyword.arg_loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
if keyword.arg in typ.args:
|
||||||
|
self._unify(keyword.value.type, typ.args[keyword.arg],
|
||||||
|
keyword.value.loc, None)
|
||||||
|
elif keyword.arg in typ.optargs:
|
||||||
|
self._unify(keyword.value.type, typ.optargs[keyword.arg],
|
||||||
|
keyword.value.loc, None)
|
||||||
|
passed_args.add(keyword.arg)
|
||||||
|
|
||||||
|
for formalname in typ.args:
|
||||||
|
if formalname not in passed_args:
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"the called function is of type {type}",
|
||||||
|
{"type": types.TypePrinter().name(node.func.type)},
|
||||||
|
node.func.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"mandatory argument '{name}' is not passed",
|
||||||
|
{"name": formalname},
|
||||||
|
node.begin_loc.join(node.end_loc), [], [note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
def visit_LambdaT(self, node):
|
def visit_LambdaT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
signature_type = self._type_from_arguments(node.args, node.body.type)
|
signature_type = self._type_from_arguments(node.args, node.body.type)
|
||||||
|
@ -818,7 +899,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
return OrderedDict(args)
|
return OrderedDict(args)
|
||||||
|
|
||||||
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
|
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
|
||||||
extract_args(node.args[len(node.defaults):]),
|
extract_args(node.args[len(node.args) - len(node.defaults):]),
|
||||||
ret)
|
ret)
|
||||||
|
|
||||||
def visit_arguments(self, node):
|
def visit_arguments(self, node):
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: cannot call this expression of type int
|
||||||
|
(1)()
|
||||||
|
|
||||||
|
def f(x, y, z=1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
|
||||||
|
f(*[])
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
|
||||||
|
f(**[])
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: the argument 'x' is already passed
|
||||||
|
f(1, x=1)
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: mandatory argument 'x' is not passed
|
||||||
|
f()
|
|
@ -0,0 +1,12 @@
|
||||||
|
# RUN: %python -m artiq.py2llvm.typing %s >%t
|
||||||
|
|
||||||
|
def _gcd(a, b):
|
||||||
|
if a < 0:
|
||||||
|
a = -a
|
||||||
|
while a:
|
||||||
|
c = a
|
||||||
|
a = b % a
|
||||||
|
b = c
|
||||||
|
return b
|
||||||
|
|
||||||
|
_gcd(10, 25)
|
Loading…
Reference in New Issue