transforms.inferencer: factor out _unify_attribute.

This commit is contained in:
whitequark 2016-01-04 22:13:05 +08:00
parent 03dd1c3a43
commit 5baf18ba0d
3 changed files with 50 additions and 46 deletions

View File

@ -286,10 +286,7 @@ class StitchingInferencer(Inferencer):
self.value_map = value_map self.value_map = value_map
self.quote = quote self.quote = quote
def visit_AttributeT(self, node): def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
self.generic_visit(node)
object_type = node.value.type.find()
# The inferencer can only observe types, not values; however, # The inferencer can only observe types, not values; however,
# when we work with host objects, we have to get the values # when we work with host objects, we have to get the values
# somewhere, since host interpreter does not have types. # somewhere, since host interpreter does not have types.
@ -304,28 +301,31 @@ class StitchingInferencer(Inferencer):
# * a previously unknown attribute is encountered, # * a previously unknown attribute is encountered,
# * a previously unknown host object is encountered; # * a previously unknown host object is encountered;
# which would be the optimal solution. # which would be the optimal solution.
object_type = value_node.type.find()
attr_value_type = None
for object_value, object_loc in self.value_map[object_type]: for object_value, object_loc in self.value_map[object_type]:
if not hasattr(object_value, node.attr): if not hasattr(object_value, attr_name):
if node.attr.startswith('_'): if attr_name.startswith('_'):
names = set(filter(lambda name: not name.startswith('_'), names = set(filter(lambda name: not name.startswith('_'),
dir(object_value))) dir(object_value)))
else: else:
names = set(dir(object_value)) names = set(dir(object_value))
suggestion = suggest_identifier(node.attr, names) suggestion = suggest_identifier(attr_name, names)
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"attribute accessed here", {}, "attribute accessed here", {},
node.loc) loc)
if suggestion is not None: if suggestion is not None:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object does not have an attribute '{attr}'; " "host object does not have an attribute '{attr}'; "
"did you mean '{suggestion}'?", "did you mean '{suggestion}'?",
{"attr": node.attr, "suggestion": suggestion}, {"attr": attr_name, "suggestion": suggestion},
object_loc, notes=[note]) object_loc, notes=[note])
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object does not have an attribute '{attr}'", "host object does not have an attribute '{attr}'",
{"attr": node.attr}, {"attr": attr_name},
object_loc, notes=[note]) object_loc, notes=[note])
self.engine.process(diag) self.engine.process(diag)
return return
@ -335,7 +335,7 @@ class StitchingInferencer(Inferencer):
# overhead (i.e. synthesizing a source buffer), but has the advantage # overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and # of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors. # also immediately getting proper diagnostics on type errors.
attr_value = getattr(object_value, node.attr) attr_value = getattr(object_value, attr_name)
if inspect.ismethod(attr_value) and types.is_instance(object_type): if inspect.ismethod(attr_value) and types.is_instance(object_type):
# In cases like: # In cases like:
# class c: # class c:
@ -349,8 +349,6 @@ class StitchingInferencer(Inferencer):
attributes = object_type.attributes attributes = object_type.attributes
is_method = False is_method = False
attr_value_type = None
if isinstance(attr_value, list): if isinstance(attr_value, list):
# Fast path for lists of scalars. # Fast path for lists of scalars.
IS_FLOAT = 1 IS_FLOAT = 1
@ -387,8 +385,8 @@ class StitchingInferencer(Inferencer):
def proxy_diagnostic(diag): def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"while inferring a type for an attribute '{attr}' of a host object", "while inferring a type for an attribute '{attr}' of a host object",
{"attr": node.attr}, {"attr": attr_name},
node.loc) loc)
diag.notes.append(note) diag.notes.append(note)
self.engine.process(diag) self.engine.process(diag)
@ -399,31 +397,26 @@ class StitchingInferencer(Inferencer):
IntMonomorphizer(engine=proxy_engine).visit(ast) IntMonomorphizer(engine=proxy_engine).visit(ast)
attr_value_type = ast.type attr_value_type = ast.type
if is_method and types.is_rpc_function(attr_value_type): if attr_name not in attributes:
self_type = list(attr_value_type.args.values())[0]
self._unify(object_type, self_type,
node.loc, None)
if node.attr not in attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
attributes[node.attr] = attr_value_type attributes[attr_name] = attr_value_type
elif not types.is_rpc_function(attr_value_type): elif not types.is_rpc_function(attr_value_type):
# Does this conflict with an earlier guess? # Does this conflict with an earlier guess?
# RPC function types are exempt because RPCs are dynamically typed. # RPC function types are exempt because RPCs are dynamically typed.
try: try:
attributes[node.attr].unify(attr_value_type) attributes[attr_name].unify(attr_value_type)
except types.UnificationError as e: except types.UnificationError as e:
printer = types.TypePrinter() printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object has an attribute '{attr}' of type {typea}, which is" "host object has an attribute '{attr}' of type {typea}, which is"
" different from previously inferred type {typeb} for the same attribute", " different from previously inferred type {typeb} for the same attribute",
{"typea": printer.name(attr_value_type), {"typea": printer.name(attr_value_type),
"typeb": printer.name(attributes[node.attr]), "typeb": printer.name(attributes[attr_name]),
"attr": node.attr}, "attr": node.attr},
object_loc) object_loc)
self.engine.process(diag) self.engine.process(diag)
super().visit_AttributeT(node) super()._unify_attribute(result_type, value_node, attr_name, attr_loc, loc)
class TypedtreeHasher(algorithm.Visitor): class TypedtreeHasher(algorithm.Visitor):
def generic_visit(self, node): def generic_visit(self, node):

View File

@ -88,9 +88,14 @@ class Inferencer(algorithm.Visitor):
def visit_AttributeT(self, node): def visit_AttributeT(self, node):
self.generic_visit(node) self.generic_visit(node)
object_type = node.value.type.find() self._unify_attribute(result_type=node.type, value_node=node.value,
attr_name=node.attr, attr_loc=node.attr_loc,
loc=node.loc)
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
object_type = value_node.type.find()
if not types.is_var(object_type): if not types.is_var(object_type):
if node.attr in object_type.attributes: if attr_name in object_type.attributes:
def makenotes(printer, typea, typeb, loca, locb): def makenotes(printer, typea, typeb, loca, locb):
return [ return [
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
@ -100,18 +105,18 @@ class Inferencer(algorithm.Visitor):
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"expression of type {typeb}", "expression of type {typeb}",
{"typeb": printer.name(object_type)}, {"typeb": printer.name(object_type)},
node.value.loc) value_node.loc)
] ]
attr_type = object_type.attributes[node.attr] attr_type = object_type.attributes[attr_name]
if types.is_rpc_function(attr_type): if types.is_rpc_function(attr_type):
attr_type = types.instantiate(attr_type) attr_type = types.instantiate(attr_type)
self._unify(node.type, attr_type, node.loc, None, self._unify(result_type, attr_type, loc, None,
makenotes=makenotes, when=" for attribute '{}'".format(node.attr)) makenotes=makenotes, when=" for attribute '{}'".format(attr_name))
elif types.is_instance(object_type) and \ elif types.is_instance(object_type) and \
node.attr in object_type.constructor.attributes: attr_name in object_type.constructor.attributes:
attr_type = object_type.constructor.attributes[node.attr].find() attr_type = object_type.constructor.attributes[attr_name].find()
if types.is_rpc_function(attr_type): if types.is_rpc_function(attr_type):
attr_type = types.instantiate(attr_type) attr_type = types.instantiate(attr_type)
@ -120,48 +125,54 @@ class Inferencer(algorithm.Visitor):
if len(attr_type.args) < 1: if len(attr_type.args) < 1:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"function '{attr}{type}' of class '{class}' cannot accept a self argument", "function '{attr}{type}' of class '{class}' cannot accept a self argument",
{"attr": node.attr, "type": types.TypePrinter().name(attr_type), {"attr": attr_name, "type": types.TypePrinter().name(attr_type),
"class": object_type.name}, "class": object_type.name},
node.loc) loc)
self.engine.process(diag) self.engine.process(diag)
return return
else: else:
def makenotes(printer, typea, typeb, loca, locb): def makenotes(printer, typea, typeb, loca, locb):
if attr_loc is None:
msgb = "reference to an instance with a method '{attr}{typeb}'"
else:
msgb = "reference to a method '{attr}{typeb}'"
return [ return [
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"expression of type {typea}", "expression of type {typea}",
{"typea": printer.name(typea)}, {"typea": printer.name(typea)},
loca), loca),
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"reference to a class function of type {typeb}", msgb,
{"typeb": printer.name(attr_type)}, {"attr": attr_name,
"typeb": printer.name(attr_type)},
locb) locb)
] ]
self._unify(object_type, list(attr_type.args.values())[0], self._unify(object_type, list(attr_type.args.values())[0],
node.value.loc, node.loc, value_node.loc, loc,
makenotes=makenotes, makenotes=makenotes,
when=" while inferring the type for self argument") when=" while inferring the type for self argument")
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, attr_type)
if not types.is_var(attr_type): if not types.is_var(attr_type):
self._unify(node.type, attr_type, self._unify(result_type, attr_type,
node.loc, None) loc, None)
else: else:
if node.attr_loc.source_buffer == node.value.loc.source_buffer: if attr_name_loc.source_buffer == value_node.loc.source_buffer:
highlights, notes = [node.value.loc], [] highlights, notes = [value_node.loc], []
else: else:
# This happens when the object being accessed is embedded # This happens when the object being accessed is embedded
# from the host program. # from the host program.
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"object being accessed", {}, "object being accessed", {},
node.value.loc) value_node.loc)
highlights, notes = [], [note] highlights, notes = [], [note]
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'", "type {type} does not have an attribute '{attr}'",
{"type": types.TypePrinter().name(object_type), "attr": node.attr}, {"type": types.TypePrinter().name(object_type), "attr": attr_name},
node.attr_loc, highlights, notes) node.attr_loc, highlights, notes)
self.engine.process(diag) self.engine.process(diag)

View File

@ -727,9 +727,9 @@ class TypePrinter(object):
signature += " " + self.name(delay) signature += " " + self.name(delay)
if isinstance(typ, TRPCFunction): if isinstance(typ, TRPCFunction):
return "rpc({}) {}".format(typ.service, signature) return "[rpc #{}]{}".format(typ.service, signature)
if isinstance(typ, TCFunction): if isinstance(typ, TCFunction):
return "ffi({}) {}".format(repr(typ.name), signature) return "[ffi {}]{}".format(repr(typ.name), signature)
elif isinstance(typ, TFunction): elif isinstance(typ, TFunction):
return signature return signature
elif isinstance(typ, TBuiltinFunction): elif isinstance(typ, TBuiltinFunction):