compiler.embedding: instantiate RPC function types (#180).

This commit is contained in:
whitequark 2015-11-27 14:29:15 +08:00
parent 3b529c6f90
commit cde21bcd26
1 changed files with 22 additions and 10 deletions

View File

@ -102,6 +102,9 @@ class ASTSynthesizer:
loc = quote_loc.join(unquote_loc) loc = quote_loc.join(unquote_loc)
function_name, function_type = self.quote_function(value, self.expanded_from) function_name, function_type = self.quote_function(value, self.expanded_from)
if function_name is None:
return asttyped.QuoteT(value=value, type=function_type, loc=loc)
else:
return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc) return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc)
else: else:
quote_loc = self._add('`') quote_loc = self._add('`')
@ -499,11 +502,12 @@ class Stitcher:
# of the module with the function name. However, since we run # of the module with the function name. However, since we run
# ASTTypedRewriter on the function node directly, we need to do it # ASTTypedRewriter on the function node directly, we need to do it
# explicitly. # explicitly.
self.globals[function_node.name] = types.TVar() function_type = types.TVar()
self.globals[function_node.name] = function_type
# Memoize the function before typing it to handle recursive # Memoize the function before typing it to handle recursive
# invocations. # invocations.
self.functions[function] = function_node.name self.functions[function] = function_node.name, function_type
# Rewrite into typed form. # Rewrite into typed form.
asttyped_rewriter = StitchingASTTypedRewriter( asttyped_rewriter = StitchingASTTypedRewriter(
@ -647,21 +651,29 @@ class Stitcher:
if syscall is None: if syscall is None:
function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, function_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
service=self.object_map.store(function)) service=self.object_map.store(function))
function_name = "rpc${}".format(function_type.service)
else: else:
function_type = types.TCFunction(arg_types, ret_type, function_type = types.TCFunction(arg_types, ret_type,
name=syscall) name=syscall)
function_name = "ffi${}".format(function_type.name)
self.globals[function_name] = function_type self.functions[function] = None, function_type
self.functions[function] = function_name
return function_name, function_type return None, function_type
def _quote_function(self, function, loc): def _quote_function(self, function, loc):
def instantiate(typ):
tvar_map = dict()
typ = typ.find()
if types.is_var(typ):
if typ not in tvar_map:
tvar_map[typ] = types.TVar()
return tvar_map[typ]
return typ
if function in self.functions: if function in self.functions:
function_name = self.functions[function] function_name, function_type = self.functions[function]
return function_name, self.globals[function_name] if types.is_rpc_function(function_type):
function_type = function_type.map(instantiate)
return function_name, function_type
if hasattr(function, "artiq_embedded"): if hasattr(function, "artiq_embedded"):
if function.artiq_embedded.function is not None: if function.artiq_embedded.function is not None: