forked from M-Labs/artiq
compiler.embedding: cache attribute types (fixes #276).
This commit is contained in:
parent
d899d7307e
commit
f838b8be49
@ -295,6 +295,102 @@ class StitchingInferencer(Inferencer):
|
||||
super().__init__(engine)
|
||||
self.value_map = value_map
|
||||
self.quote = quote
|
||||
self.attr_type_cache = {}
|
||||
|
||||
def _compute_value_type(self, object_value, object_type, object_loc, attr_name, loc):
|
||||
if not hasattr(object_value, attr_name):
|
||||
if attr_name.startswith('_'):
|
||||
names = set(filter(lambda name: not name.startswith('_'),
|
||||
dir(object_value)))
|
||||
else:
|
||||
names = set(dir(object_value))
|
||||
suggestion = suggest_identifier(attr_name, names)
|
||||
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"attribute accessed here", {},
|
||||
loc)
|
||||
if suggestion is not None:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"host object does not have an attribute '{attr}'; "
|
||||
"did you mean '{suggestion}'?",
|
||||
{"attr": attr_name, "suggestion": suggestion},
|
||||
object_loc, notes=[note])
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"host object does not have an attribute '{attr}'",
|
||||
{"attr": attr_name},
|
||||
object_loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
# Figure out what ARTIQ type does the value of the attribute have.
|
||||
# We do this by quoting it, as if to serialize. This has some
|
||||
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
||||
# of having the host-to-ARTIQ mapping code in only one place and
|
||||
# also immediately getting proper diagnostics on type errors.
|
||||
attr_value = getattr(object_value, attr_name)
|
||||
if inspect.ismethod(attr_value) and types.is_instance(object_type):
|
||||
# In cases like:
|
||||
# class c:
|
||||
# @kernel
|
||||
# def f(self): pass
|
||||
# we want f to be defined on the class, not on the instance.
|
||||
attributes = object_type.constructor.attributes
|
||||
attr_value = attr_value.__func__
|
||||
else:
|
||||
attributes = object_type.attributes
|
||||
|
||||
attr_value_type = None
|
||||
|
||||
if isinstance(attr_value, list):
|
||||
# Fast path for lists of scalars.
|
||||
IS_FLOAT = 1
|
||||
IS_INT32 = 2
|
||||
IS_INT64 = 4
|
||||
|
||||
state = 0
|
||||
for elt in attr_value:
|
||||
if elt.__class__ == float:
|
||||
state |= IS_FLOAT
|
||||
elif elt.__class__ == int:
|
||||
if -2**31 < elt < 2**31-1:
|
||||
state |= IS_INT32
|
||||
elif -2**63 < elt < 2**63-1:
|
||||
state |= IS_INT64
|
||||
else:
|
||||
state = -1
|
||||
break
|
||||
else:
|
||||
state = -1
|
||||
|
||||
if state == IS_FLOAT:
|
||||
attr_value_type = builtins.TList(builtins.TFloat())
|
||||
elif state == IS_INT32:
|
||||
attr_value_type = builtins.TList(builtins.TInt32())
|
||||
elif state == IS_INT64:
|
||||
attr_value_type = builtins.TList(builtins.TInt64())
|
||||
|
||||
if attr_value_type is None:
|
||||
# Slow path. We don't know what exactly is the attribute value,
|
||||
# so we quote it only for the error message that may possibly result.
|
||||
ast = self.quote(attr_value, object_loc.expanded_from)
|
||||
|
||||
def proxy_diagnostic(diag):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"while inferring a type for an attribute '{attr}' of a host object",
|
||||
{"attr": attr_name},
|
||||
loc)
|
||||
diag.notes.append(note)
|
||||
|
||||
self.engine.process(diag)
|
||||
|
||||
proxy_engine = diagnostic.Engine()
|
||||
proxy_engine.process = proxy_diagnostic
|
||||
Inferencer(engine=proxy_engine).visit(ast)
|
||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||
attr_value_type = ast.type
|
||||
|
||||
return attributes, attr_value_type
|
||||
|
||||
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
|
||||
# The inferencer can only observe types, not values; however,
|
||||
@ -304,108 +400,15 @@ class StitchingInferencer(Inferencer):
|
||||
# its type, we now interrogate every host object we have to ensure
|
||||
# that we can successfully serialize the value of the attribute we
|
||||
# are now adding at the code generation stage.
|
||||
#
|
||||
# FIXME: We perform exhaustive checks of every known host object every
|
||||
# time an attribute access is visited, which is potentially quadratic.
|
||||
# This is done because it is simpler than performing the checks only when:
|
||||
# * a previously unknown attribute is encountered,
|
||||
# * a previously unknown host object is encountered;
|
||||
# which would be the optimal solution.
|
||||
|
||||
object_type = value_node.type.find()
|
||||
for object_value, object_loc in self.value_map[object_type]:
|
||||
attr_value_type = None
|
||||
if not hasattr(object_value, attr_name):
|
||||
if attr_name.startswith('_'):
|
||||
names = set(filter(lambda name: not name.startswith('_'),
|
||||
dir(object_value)))
|
||||
else:
|
||||
names = set(dir(object_value))
|
||||
suggestion = suggest_identifier(attr_name, names)
|
||||
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"attribute accessed here", {},
|
||||
loc)
|
||||
if suggestion is not None:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"host object does not have an attribute '{attr}'; "
|
||||
"did you mean '{suggestion}'?",
|
||||
{"attr": attr_name, "suggestion": suggestion},
|
||||
object_loc, notes=[note])
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"host object does not have an attribute '{attr}'",
|
||||
{"attr": attr_name},
|
||||
object_loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
# Figure out what ARTIQ type does the value of the attribute have.
|
||||
# We do this by quoting it, as if to serialize. This has some
|
||||
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
||||
# of having the host-to-ARTIQ mapping code in only one place and
|
||||
# also immediately getting proper diagnostics on type errors.
|
||||
attr_value = getattr(object_value, attr_name)
|
||||
if inspect.ismethod(attr_value) and types.is_instance(object_type):
|
||||
# In cases like:
|
||||
# class c:
|
||||
# @kernel
|
||||
# def f(self): pass
|
||||
# we want f to be defined on the class, not on the instance.
|
||||
attributes = object_type.constructor.attributes
|
||||
attr_value = attr_value.__func__
|
||||
is_method = True
|
||||
else:
|
||||
attributes = object_type.attributes
|
||||
is_method = False
|
||||
|
||||
if isinstance(attr_value, list):
|
||||
# Fast path for lists of scalars.
|
||||
IS_FLOAT = 1
|
||||
IS_INT32 = 2
|
||||
IS_INT64 = 4
|
||||
|
||||
state = 0
|
||||
for elt in attr_value:
|
||||
if elt.__class__ == float:
|
||||
state |= IS_FLOAT
|
||||
elif elt.__class__ == int:
|
||||
if -2**31 < elt < 2**31-1:
|
||||
state |= IS_INT32
|
||||
elif -2**63 < elt < 2**63-1:
|
||||
state |= IS_INT64
|
||||
else:
|
||||
state = -1
|
||||
break
|
||||
else:
|
||||
state = -1
|
||||
|
||||
if state == IS_FLOAT:
|
||||
attr_value_type = builtins.TList(builtins.TFloat())
|
||||
elif state == IS_INT32:
|
||||
attr_value_type = builtins.TList(builtins.TInt32())
|
||||
elif state == IS_INT64:
|
||||
attr_value_type = builtins.TList(builtins.TInt64())
|
||||
|
||||
if attr_value_type is None:
|
||||
# Slow path. We don't know what exactly is the attribute value,
|
||||
# so we quote it only for the error message that may possibly result.
|
||||
ast = self.quote(attr_value, object_loc.expanded_from)
|
||||
|
||||
def proxy_diagnostic(diag):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"while inferring a type for an attribute '{attr}' of a host object",
|
||||
{"attr": attr_name},
|
||||
loc)
|
||||
diag.notes.append(note)
|
||||
|
||||
self.engine.process(diag)
|
||||
|
||||
proxy_engine = diagnostic.Engine()
|
||||
proxy_engine.process = proxy_diagnostic
|
||||
Inferencer(engine=proxy_engine).visit(ast)
|
||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||
attr_value_type = ast.type
|
||||
attr_type_key = (id(object_value), attr_name)
|
||||
try:
|
||||
attributes, attr_value_type = self.attr_type_cache[attr_type_key]
|
||||
except KeyError:
|
||||
attributes, attr_value_type = \
|
||||
self._compute_value_type(object_value, object_type, object_loc, attr_name, loc)
|
||||
self.attr_type_cache[attr_type_key] = attributes, attr_value_type
|
||||
|
||||
if attr_name not in attributes:
|
||||
# We just figured out what the type should be. Add it.
|
||||
|
Loading…
Reference in New Issue
Block a user