diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index ebac15754..1b930f4f8 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -102,7 +102,10 @@ class ASTSynthesizer: loc = quote_loc.join(unquote_loc) function_name, function_type = self.quote_function(value, self.expanded_from) - return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc) + if function_name is None: + return asttyped.QuoteT(value=value, type=function_type, loc=loc) + else: + return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc) else: quote_loc = self._add('`') repr_loc = self._add(repr(value)) @@ -499,11 +502,12 @@ class Stitcher: # of the module with the function name. However, since we run # ASTTypedRewriter on the function node directly, we need to do it # explicitly. - self.globals[function_node.name] = types.TVar() + function_type = types.TVar() + self.globals[function_node.name] = function_type # Memoize the function before typing it to handle recursive # invocations. - self.functions[function] = function_node.name + self.functions[function] = function_node.name, function_type # Rewrite into typed form. asttyped_rewriter = StitchingASTTypedRewriter( @@ -647,21 +651,29 @@ class Stitcher: if syscall is None: function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, service=self.object_map.store(function)) - function_name = "rpc${}".format(function_type.service) else: function_type = types.TCFunction(arg_types, ret_type, name=syscall) - function_name = "ffi${}".format(function_type.name) - self.globals[function_name] = function_type - self.functions[function] = function_name + self.functions[function] = None, function_type - return function_name, function_type + return None, function_type def _quote_function(self, function, loc): + def instantiate(typ): + tvar_map = dict() + typ = typ.find() + if types.is_var(typ): + if typ not in tvar_map: + tvar_map[typ] = types.TVar() + return tvar_map[typ] + return typ + if function in self.functions: - function_name = self.functions[function] - return function_name, self.globals[function_name] + function_name, function_type = self.functions[function] + if types.is_rpc_function(function_type): + function_type = function_type.map(instantiate) + return function_name, function_type if hasattr(function, "artiq_embedded"): if function.artiq_embedded.function is not None: