forked from M-Labs/artiq
1
0
Fork 0

compiler.embedding: instantiate RPC method types (fixes #180).

This commit is contained in:
whitequark 2015-11-27 16:29:13 +08:00
parent 0a794fe7e4
commit 16ae0fb6eb
6 changed files with 56 additions and 16 deletions

View File

@ -342,8 +342,9 @@ class StitchingInferencer(Inferencer):
if node.attr not in attributes: if node.attr not in attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
attributes[node.attr] = ast.type attributes[node.attr] = ast.type
elif attributes[node.attr] != ast.type: elif attributes[node.attr] != ast.type and not types.is_rpc_function(ast.type):
# Does this conflict with an earlier guess? # Does this conflict with an earlier guess?
# RPC function types are exempt because RPCs are dynamically typed.
printer = types.TypePrinter() printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object has an attribute '{attr}' of type {typea}, which is" "host object has an attribute '{attr}' of type {typea}, which is"
@ -660,15 +661,6 @@ class Stitcher:
return None, function_type return None, function_type
def _quote_function(self, function, loc): def _quote_function(self, function, loc):
def instantiate(typ):
tvar_map = dict()
typ = typ.find()
if types.is_var(typ):
if typ not in tvar_map:
tvar_map[typ] = types.TVar()
return tvar_map[typ]
return typ
if function in self.functions: if function in self.functions:
result = self.functions[function] result = self.functions[function]
else: else:
@ -694,7 +686,7 @@ class Stitcher:
function_name, function_type = result function_name, function_type = result
if types.is_rpc_function(function_type): if types.is_rpc_function(function_type):
function_type = function_type.map(instantiate) function_type = types.instantiate(function_type)
return function_name, function_type return function_name, function_type
def _quote(self, value, loc): def _quote(self, value, loc):

View File

@ -103,13 +103,14 @@ class Inferencer(algorithm.Visitor):
node.value.loc) node.value.loc)
] ]
# Assumes no free type variables in .attributes. attr_type = object_type.attributes[node.attr]
self._unify(node.type, object_type.attributes[node.attr], if types.is_function(attr_type):
node.loc, None, attr_type = types.instantiate(attr_type)
self._unify(node.type, attr_type, node.loc, None,
makenotes=makenotes, when=" for attribute '{}'".format(node.attr)) makenotes=makenotes, when=" for attribute '{}'".format(node.attr))
elif types.is_instance(object_type) and \ elif types.is_instance(object_type) and \
node.attr in object_type.constructor.attributes: node.attr in object_type.constructor.attributes:
# Assumes no free type variables in .attributes.
attr_type = object_type.constructor.attributes[node.attr].find() attr_type = object_type.constructor.attributes[node.attr].find()
if types.is_function(attr_type): if types.is_function(attr_type):
# Convert to a method. # Convert to a method.
@ -139,7 +140,7 @@ class Inferencer(algorithm.Visitor):
makenotes=makenotes, makenotes=makenotes,
when=" while inferring the type for self argument") when=" while inferring the type for self argument")
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, types.instantiate(attr_type))
if not types.is_var(attr_type): if not types.is_var(attr_type):
self._unify(node.type, attr_type, self._unify(node.type, attr_type,

View File

@ -539,6 +539,18 @@ def TFixedDelay(duration):
return TDelay(duration, None) return TDelay(duration, None)
def instantiate(typ):
tvar_map = dict()
def mapper(typ):
typ = typ.find()
if is_var(typ):
if typ not in tvar_map:
tvar_map[typ] = TVar()
return tvar_map[typ]
return typ
return typ.map(mapper)
def is_var(typ): def is_var(typ):
return isinstance(typ.find(), TVar) return isinstance(typ.find(), TVar)

View File

@ -16,3 +16,12 @@ class Comm:
def serve(self, object_map, symbolizer): def serve(self, object_map, symbolizer):
pass pass
def check_ident(self):
pass
def get_log(self):
return ""
def clear_log(self):
pass

View File

@ -0,0 +1,12 @@
# RUN: %python -m artiq.compiler.testbench.embedding %s
from artiq.language.core import *
from artiq.language.types import *
def f(x):
print(x)
@kernel
def entrypoint():
f("foo")
f(42)

View File

@ -0,0 +1,14 @@
# RUN: %python -m artiq.compiler.testbench.embedding %s
from artiq.language.core import *
from artiq.language.types import *
class c:
def p(self, foo):
print(foo)
i = c()
@kernel
def entrypoint():
i.p("foo")
i.p(42)