From 16ae0fb6eb6a765b17055c27342e1967f6876af4 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 27 Nov 2015 16:29:13 +0800 Subject: [PATCH] compiler.embedding: instantiate RPC method types (fixes #180). --- artiq/compiler/embedding.py | 14 +++----------- artiq/compiler/transforms/inferencer.py | 11 ++++++----- artiq/compiler/types.py | 12 ++++++++++++ artiq/coredevice/comm_dummy.py | 9 +++++++++ lit-test/test/embedding/function_polymorphism.py | 12 ++++++++++++ lit-test/test/embedding/method_polymorphism.py | 14 ++++++++++++++ 6 files changed, 56 insertions(+), 16 deletions(-) create mode 100644 lit-test/test/embedding/function_polymorphism.py create mode 100644 lit-test/test/embedding/method_polymorphism.py diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index f64cb7ef0..9e31dbb14 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -342,8 +342,9 @@ class StitchingInferencer(Inferencer): if node.attr not in attributes: # We just figured out what the type should be. Add it. attributes[node.attr] = ast.type - elif attributes[node.attr] != ast.type: + elif attributes[node.attr] != ast.type and not types.is_rpc_function(ast.type): # Does this conflict with an earlier guess? + # RPC function types are exempt because RPCs are dynamically typed. printer = types.TypePrinter() diag = diagnostic.Diagnostic("error", "host object has an attribute '{attr}' of type {typea}, which is" @@ -660,15 +661,6 @@ class Stitcher: return None, function_type def _quote_function(self, function, loc): - def instantiate(typ): - tvar_map = dict() - typ = typ.find() - if types.is_var(typ): - if typ not in tvar_map: - tvar_map[typ] = types.TVar() - return tvar_map[typ] - return typ - if function in self.functions: result = self.functions[function] else: @@ -694,7 +686,7 @@ class Stitcher: function_name, function_type = result if types.is_rpc_function(function_type): - function_type = function_type.map(instantiate) + function_type = types.instantiate(function_type) return function_name, function_type def _quote(self, value, loc): diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index f823014f8..24f3d94b2 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -103,13 +103,14 @@ class Inferencer(algorithm.Visitor): node.value.loc) ] - # Assumes no free type variables in .attributes. - self._unify(node.type, object_type.attributes[node.attr], - node.loc, None, + attr_type = object_type.attributes[node.attr] + if types.is_function(attr_type): + attr_type = types.instantiate(attr_type) + + self._unify(node.type, attr_type, node.loc, None, makenotes=makenotes, when=" for attribute '{}'".format(node.attr)) elif types.is_instance(object_type) and \ node.attr in object_type.constructor.attributes: - # Assumes no free type variables in .attributes. attr_type = object_type.constructor.attributes[node.attr].find() if types.is_function(attr_type): # Convert to a method. @@ -139,7 +140,7 @@ class Inferencer(algorithm.Visitor): makenotes=makenotes, when=" while inferring the type for self argument") - attr_type = types.TMethod(object_type, attr_type) + attr_type = types.TMethod(object_type, types.instantiate(attr_type)) if not types.is_var(attr_type): self._unify(node.type, attr_type, diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index b6f911a32..08d9fcd2c 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -539,6 +539,18 @@ def TFixedDelay(duration): return TDelay(duration, None) +def instantiate(typ): + tvar_map = dict() + def mapper(typ): + typ = typ.find() + if is_var(typ): + if typ not in tvar_map: + tvar_map[typ] = TVar() + return tvar_map[typ] + return typ + + return typ.map(mapper) + def is_var(typ): return isinstance(typ.find(), TVar) diff --git a/artiq/coredevice/comm_dummy.py b/artiq/coredevice/comm_dummy.py index 82a1a9575..c8626bb17 100644 --- a/artiq/coredevice/comm_dummy.py +++ b/artiq/coredevice/comm_dummy.py @@ -16,3 +16,12 @@ class Comm: def serve(self, object_map, symbolizer): pass + + def check_ident(self): + pass + + def get_log(self): + return "" + + def clear_log(self): + pass diff --git a/lit-test/test/embedding/function_polymorphism.py b/lit-test/test/embedding/function_polymorphism.py new file mode 100644 index 000000000..52df5ee4f --- /dev/null +++ b/lit-test/test/embedding/function_polymorphism.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s + +from artiq.language.core import * +from artiq.language.types import * + +def f(x): + print(x) + +@kernel +def entrypoint(): + f("foo") + f(42) diff --git a/lit-test/test/embedding/method_polymorphism.py b/lit-test/test/embedding/method_polymorphism.py new file mode 100644 index 000000000..b7ca9f525 --- /dev/null +++ b/lit-test/test/embedding/method_polymorphism.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s + +from artiq.language.core import * +from artiq.language.types import * + +class c: + def p(self, foo): + print(foo) +i = c() + +@kernel +def entrypoint(): + i.p("foo") + i.p(42)