From f838b8be497ac68206565d0f0582873b7136ccbf Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 25 Feb 2016 19:56:45 +0000 Subject: [PATCH] compiler.embedding: cache attribute types (fixes #276). --- artiq/compiler/embedding.py | 203 ++++++++++++++++++------------------ 1 file changed, 103 insertions(+), 100 deletions(-) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 7cd53eaaa..17155d767 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -295,6 +295,102 @@ class StitchingInferencer(Inferencer): super().__init__(engine) self.value_map = value_map self.quote = quote + self.attr_type_cache = {} + + def _compute_value_type(self, object_value, object_type, object_loc, attr_name, loc): + if not hasattr(object_value, attr_name): + if attr_name.startswith('_'): + names = set(filter(lambda name: not name.startswith('_'), + dir(object_value))) + else: + names = set(dir(object_value)) + suggestion = suggest_identifier(attr_name, names) + + note = diagnostic.Diagnostic("note", + "attribute accessed here", {}, + loc) + if suggestion is not None: + diag = diagnostic.Diagnostic("error", + "host object does not have an attribute '{attr}'; " + "did you mean '{suggestion}'?", + {"attr": attr_name, "suggestion": suggestion}, + object_loc, notes=[note]) + else: + diag = diagnostic.Diagnostic("error", + "host object does not have an attribute '{attr}'", + {"attr": attr_name}, + object_loc, notes=[note]) + self.engine.process(diag) + return + + # Figure out what ARTIQ type does the value of the attribute have. + # We do this by quoting it, as if to serialize. This has some + # overhead (i.e. synthesizing a source buffer), but has the advantage + # of having the host-to-ARTIQ mapping code in only one place and + # also immediately getting proper diagnostics on type errors. + attr_value = getattr(object_value, attr_name) + if inspect.ismethod(attr_value) and types.is_instance(object_type): + # In cases like: + # class c: + # @kernel + # def f(self): pass + # we want f to be defined on the class, not on the instance. + attributes = object_type.constructor.attributes + attr_value = attr_value.__func__ + else: + attributes = object_type.attributes + + attr_value_type = None + + if isinstance(attr_value, list): + # Fast path for lists of scalars. + IS_FLOAT = 1 + IS_INT32 = 2 + IS_INT64 = 4 + + state = 0 + for elt in attr_value: + if elt.__class__ == float: + state |= IS_FLOAT + elif elt.__class__ == int: + if -2**31 < elt < 2**31-1: + state |= IS_INT32 + elif -2**63 < elt < 2**63-1: + state |= IS_INT64 + else: + state = -1 + break + else: + state = -1 + + if state == IS_FLOAT: + attr_value_type = builtins.TList(builtins.TFloat()) + elif state == IS_INT32: + attr_value_type = builtins.TList(builtins.TInt32()) + elif state == IS_INT64: + attr_value_type = builtins.TList(builtins.TInt64()) + + if attr_value_type is None: + # Slow path. We don't know what exactly is the attribute value, + # so we quote it only for the error message that may possibly result. + ast = self.quote(attr_value, object_loc.expanded_from) + + def proxy_diagnostic(diag): + note = diagnostic.Diagnostic("note", + "while inferring a type for an attribute '{attr}' of a host object", + {"attr": attr_name}, + loc) + diag.notes.append(note) + + self.engine.process(diag) + + proxy_engine = diagnostic.Engine() + proxy_engine.process = proxy_diagnostic + Inferencer(engine=proxy_engine).visit(ast) + IntMonomorphizer(engine=proxy_engine).visit(ast) + attr_value_type = ast.type + + return attributes, attr_value_type def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc): # The inferencer can only observe types, not values; however, @@ -304,108 +400,15 @@ class StitchingInferencer(Inferencer): # its type, we now interrogate every host object we have to ensure # that we can successfully serialize the value of the attribute we # are now adding at the code generation stage. - # - # FIXME: We perform exhaustive checks of every known host object every - # time an attribute access is visited, which is potentially quadratic. - # This is done because it is simpler than performing the checks only when: - # * a previously unknown attribute is encountered, - # * a previously unknown host object is encountered; - # which would be the optimal solution. - object_type = value_node.type.find() for object_value, object_loc in self.value_map[object_type]: - attr_value_type = None - if not hasattr(object_value, attr_name): - if attr_name.startswith('_'): - names = set(filter(lambda name: not name.startswith('_'), - dir(object_value))) - else: - names = set(dir(object_value)) - suggestion = suggest_identifier(attr_name, names) - - note = diagnostic.Diagnostic("note", - "attribute accessed here", {}, - loc) - if suggestion is not None: - diag = diagnostic.Diagnostic("error", - "host object does not have an attribute '{attr}'; " - "did you mean '{suggestion}'?", - {"attr": attr_name, "suggestion": suggestion}, - object_loc, notes=[note]) - else: - diag = diagnostic.Diagnostic("error", - "host object does not have an attribute '{attr}'", - {"attr": attr_name}, - object_loc, notes=[note]) - self.engine.process(diag) - return - - # Figure out what ARTIQ type does the value of the attribute have. - # We do this by quoting it, as if to serialize. This has some - # overhead (i.e. synthesizing a source buffer), but has the advantage - # of having the host-to-ARTIQ mapping code in only one place and - # also immediately getting proper diagnostics on type errors. - attr_value = getattr(object_value, attr_name) - if inspect.ismethod(attr_value) and types.is_instance(object_type): - # In cases like: - # class c: - # @kernel - # def f(self): pass - # we want f to be defined on the class, not on the instance. - attributes = object_type.constructor.attributes - attr_value = attr_value.__func__ - is_method = True - else: - attributes = object_type.attributes - is_method = False - - if isinstance(attr_value, list): - # Fast path for lists of scalars. - IS_FLOAT = 1 - IS_INT32 = 2 - IS_INT64 = 4 - - state = 0 - for elt in attr_value: - if elt.__class__ == float: - state |= IS_FLOAT - elif elt.__class__ == int: - if -2**31 < elt < 2**31-1: - state |= IS_INT32 - elif -2**63 < elt < 2**63-1: - state |= IS_INT64 - else: - state = -1 - break - else: - state = -1 - - if state == IS_FLOAT: - attr_value_type = builtins.TList(builtins.TFloat()) - elif state == IS_INT32: - attr_value_type = builtins.TList(builtins.TInt32()) - elif state == IS_INT64: - attr_value_type = builtins.TList(builtins.TInt64()) - - if attr_value_type is None: - # Slow path. We don't know what exactly is the attribute value, - # so we quote it only for the error message that may possibly result. - ast = self.quote(attr_value, object_loc.expanded_from) - - def proxy_diagnostic(diag): - note = diagnostic.Diagnostic("note", - "while inferring a type for an attribute '{attr}' of a host object", - {"attr": attr_name}, - loc) - diag.notes.append(note) - - self.engine.process(diag) - - proxy_engine = diagnostic.Engine() - proxy_engine.process = proxy_diagnostic - Inferencer(engine=proxy_engine).visit(ast) - IntMonomorphizer(engine=proxy_engine).visit(ast) - attr_value_type = ast.type + attr_type_key = (id(object_value), attr_name) + try: + attributes, attr_value_type = self.attr_type_cache[attr_type_key] + except KeyError: + attributes, attr_value_type = \ + self._compute_value_type(object_value, object_type, object_loc, attr_name, loc) + self.attr_type_cache[attr_type_key] = attributes, attr_value_type if attr_name not in attributes: # We just figured out what the type should be. Add it.