forked from M-Labs/artiq
compiler.embedding: cache attribute types (fixes #276).
This commit is contained in:
parent
d899d7307e
commit
f838b8be49
|
@ -295,6 +295,102 @@ class StitchingInferencer(Inferencer):
|
||||||
super().__init__(engine)
|
super().__init__(engine)
|
||||||
self.value_map = value_map
|
self.value_map = value_map
|
||||||
self.quote = quote
|
self.quote = quote
|
||||||
|
self.attr_type_cache = {}
|
||||||
|
|
||||||
|
def _compute_value_type(self, object_value, object_type, object_loc, attr_name, loc):
|
||||||
|
if not hasattr(object_value, attr_name):
|
||||||
|
if attr_name.startswith('_'):
|
||||||
|
names = set(filter(lambda name: not name.startswith('_'),
|
||||||
|
dir(object_value)))
|
||||||
|
else:
|
||||||
|
names = set(dir(object_value))
|
||||||
|
suggestion = suggest_identifier(attr_name, names)
|
||||||
|
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"attribute accessed here", {},
|
||||||
|
loc)
|
||||||
|
if suggestion is not None:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"host object does not have an attribute '{attr}'; "
|
||||||
|
"did you mean '{suggestion}'?",
|
||||||
|
{"attr": attr_name, "suggestion": suggestion},
|
||||||
|
object_loc, notes=[note])
|
||||||
|
else:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"host object does not have an attribute '{attr}'",
|
||||||
|
{"attr": attr_name},
|
||||||
|
object_loc, notes=[note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Figure out what ARTIQ type does the value of the attribute have.
|
||||||
|
# We do this by quoting it, as if to serialize. This has some
|
||||||
|
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
||||||
|
# of having the host-to-ARTIQ mapping code in only one place and
|
||||||
|
# also immediately getting proper diagnostics on type errors.
|
||||||
|
attr_value = getattr(object_value, attr_name)
|
||||||
|
if inspect.ismethod(attr_value) and types.is_instance(object_type):
|
||||||
|
# In cases like:
|
||||||
|
# class c:
|
||||||
|
# @kernel
|
||||||
|
# def f(self): pass
|
||||||
|
# we want f to be defined on the class, not on the instance.
|
||||||
|
attributes = object_type.constructor.attributes
|
||||||
|
attr_value = attr_value.__func__
|
||||||
|
else:
|
||||||
|
attributes = object_type.attributes
|
||||||
|
|
||||||
|
attr_value_type = None
|
||||||
|
|
||||||
|
if isinstance(attr_value, list):
|
||||||
|
# Fast path for lists of scalars.
|
||||||
|
IS_FLOAT = 1
|
||||||
|
IS_INT32 = 2
|
||||||
|
IS_INT64 = 4
|
||||||
|
|
||||||
|
state = 0
|
||||||
|
for elt in attr_value:
|
||||||
|
if elt.__class__ == float:
|
||||||
|
state |= IS_FLOAT
|
||||||
|
elif elt.__class__ == int:
|
||||||
|
if -2**31 < elt < 2**31-1:
|
||||||
|
state |= IS_INT32
|
||||||
|
elif -2**63 < elt < 2**63-1:
|
||||||
|
state |= IS_INT64
|
||||||
|
else:
|
||||||
|
state = -1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
state = -1
|
||||||
|
|
||||||
|
if state == IS_FLOAT:
|
||||||
|
attr_value_type = builtins.TList(builtins.TFloat())
|
||||||
|
elif state == IS_INT32:
|
||||||
|
attr_value_type = builtins.TList(builtins.TInt32())
|
||||||
|
elif state == IS_INT64:
|
||||||
|
attr_value_type = builtins.TList(builtins.TInt64())
|
||||||
|
|
||||||
|
if attr_value_type is None:
|
||||||
|
# Slow path. We don't know what exactly is the attribute value,
|
||||||
|
# so we quote it only for the error message that may possibly result.
|
||||||
|
ast = self.quote(attr_value, object_loc.expanded_from)
|
||||||
|
|
||||||
|
def proxy_diagnostic(diag):
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"while inferring a type for an attribute '{attr}' of a host object",
|
||||||
|
{"attr": attr_name},
|
||||||
|
loc)
|
||||||
|
diag.notes.append(note)
|
||||||
|
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
proxy_engine = diagnostic.Engine()
|
||||||
|
proxy_engine.process = proxy_diagnostic
|
||||||
|
Inferencer(engine=proxy_engine).visit(ast)
|
||||||
|
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||||
|
attr_value_type = ast.type
|
||||||
|
|
||||||
|
return attributes, attr_value_type
|
||||||
|
|
||||||
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
|
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
|
||||||
# The inferencer can only observe types, not values; however,
|
# The inferencer can only observe types, not values; however,
|
||||||
|
@ -304,108 +400,15 @@ class StitchingInferencer(Inferencer):
|
||||||
# its type, we now interrogate every host object we have to ensure
|
# its type, we now interrogate every host object we have to ensure
|
||||||
# that we can successfully serialize the value of the attribute we
|
# that we can successfully serialize the value of the attribute we
|
||||||
# are now adding at the code generation stage.
|
# are now adding at the code generation stage.
|
||||||
#
|
|
||||||
# FIXME: We perform exhaustive checks of every known host object every
|
|
||||||
# time an attribute access is visited, which is potentially quadratic.
|
|
||||||
# This is done because it is simpler than performing the checks only when:
|
|
||||||
# * a previously unknown attribute is encountered,
|
|
||||||
# * a previously unknown host object is encountered;
|
|
||||||
# which would be the optimal solution.
|
|
||||||
|
|
||||||
object_type = value_node.type.find()
|
object_type = value_node.type.find()
|
||||||
for object_value, object_loc in self.value_map[object_type]:
|
for object_value, object_loc in self.value_map[object_type]:
|
||||||
attr_value_type = None
|
attr_type_key = (id(object_value), attr_name)
|
||||||
if not hasattr(object_value, attr_name):
|
try:
|
||||||
if attr_name.startswith('_'):
|
attributes, attr_value_type = self.attr_type_cache[attr_type_key]
|
||||||
names = set(filter(lambda name: not name.startswith('_'),
|
except KeyError:
|
||||||
dir(object_value)))
|
attributes, attr_value_type = \
|
||||||
else:
|
self._compute_value_type(object_value, object_type, object_loc, attr_name, loc)
|
||||||
names = set(dir(object_value))
|
self.attr_type_cache[attr_type_key] = attributes, attr_value_type
|
||||||
suggestion = suggest_identifier(attr_name, names)
|
|
||||||
|
|
||||||
note = diagnostic.Diagnostic("note",
|
|
||||||
"attribute accessed here", {},
|
|
||||||
loc)
|
|
||||||
if suggestion is not None:
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"host object does not have an attribute '{attr}'; "
|
|
||||||
"did you mean '{suggestion}'?",
|
|
||||||
{"attr": attr_name, "suggestion": suggestion},
|
|
||||||
object_loc, notes=[note])
|
|
||||||
else:
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"host object does not have an attribute '{attr}'",
|
|
||||||
{"attr": attr_name},
|
|
||||||
object_loc, notes=[note])
|
|
||||||
self.engine.process(diag)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Figure out what ARTIQ type does the value of the attribute have.
|
|
||||||
# We do this by quoting it, as if to serialize. This has some
|
|
||||||
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
|
||||||
# of having the host-to-ARTIQ mapping code in only one place and
|
|
||||||
# also immediately getting proper diagnostics on type errors.
|
|
||||||
attr_value = getattr(object_value, attr_name)
|
|
||||||
if inspect.ismethod(attr_value) and types.is_instance(object_type):
|
|
||||||
# In cases like:
|
|
||||||
# class c:
|
|
||||||
# @kernel
|
|
||||||
# def f(self): pass
|
|
||||||
# we want f to be defined on the class, not on the instance.
|
|
||||||
attributes = object_type.constructor.attributes
|
|
||||||
attr_value = attr_value.__func__
|
|
||||||
is_method = True
|
|
||||||
else:
|
|
||||||
attributes = object_type.attributes
|
|
||||||
is_method = False
|
|
||||||
|
|
||||||
if isinstance(attr_value, list):
|
|
||||||
# Fast path for lists of scalars.
|
|
||||||
IS_FLOAT = 1
|
|
||||||
IS_INT32 = 2
|
|
||||||
IS_INT64 = 4
|
|
||||||
|
|
||||||
state = 0
|
|
||||||
for elt in attr_value:
|
|
||||||
if elt.__class__ == float:
|
|
||||||
state |= IS_FLOAT
|
|
||||||
elif elt.__class__ == int:
|
|
||||||
if -2**31 < elt < 2**31-1:
|
|
||||||
state |= IS_INT32
|
|
||||||
elif -2**63 < elt < 2**63-1:
|
|
||||||
state |= IS_INT64
|
|
||||||
else:
|
|
||||||
state = -1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
state = -1
|
|
||||||
|
|
||||||
if state == IS_FLOAT:
|
|
||||||
attr_value_type = builtins.TList(builtins.TFloat())
|
|
||||||
elif state == IS_INT32:
|
|
||||||
attr_value_type = builtins.TList(builtins.TInt32())
|
|
||||||
elif state == IS_INT64:
|
|
||||||
attr_value_type = builtins.TList(builtins.TInt64())
|
|
||||||
|
|
||||||
if attr_value_type is None:
|
|
||||||
# Slow path. We don't know what exactly is the attribute value,
|
|
||||||
# so we quote it only for the error message that may possibly result.
|
|
||||||
ast = self.quote(attr_value, object_loc.expanded_from)
|
|
||||||
|
|
||||||
def proxy_diagnostic(diag):
|
|
||||||
note = diagnostic.Diagnostic("note",
|
|
||||||
"while inferring a type for an attribute '{attr}' of a host object",
|
|
||||||
{"attr": attr_name},
|
|
||||||
loc)
|
|
||||||
diag.notes.append(note)
|
|
||||||
|
|
||||||
self.engine.process(diag)
|
|
||||||
|
|
||||||
proxy_engine = diagnostic.Engine()
|
|
||||||
proxy_engine.process = proxy_diagnostic
|
|
||||||
Inferencer(engine=proxy_engine).visit(ast)
|
|
||||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
|
||||||
attr_value_type = ast.type
|
|
||||||
|
|
||||||
if attr_name not in attributes:
|
if attr_name not in attributes:
|
||||||
# We just figured out what the type should be. Add it.
|
# We just figured out what the type should be. Add it.
|
||||||
|
|
Loading…
Reference in New Issue