forked from M-Labs/artiq
1
0
Fork 0

compiler.embedding: add fast path for inference for lists of scalars in a host object.

This commit is contained in:
whitequark 2015-11-27 19:06:04 +08:00
parent 237e983770
commit 3c9b53b07b
1 changed files with 49 additions and 15 deletions

View File

@ -323,33 +323,67 @@ class StitchingInferencer(Inferencer):
else: else:
attributes = object_type.attributes attributes = object_type.attributes
ast = self.quote(attr_value, object_loc.expanded_from) attr_value_type = None
def proxy_diagnostic(diag): if isinstance(attr_value, list):
note = diagnostic.Diagnostic("note", # Fast path for lists of scalars.
"while inferring a type for an attribute '{attr}' of a host object", IS_FLOAT = 1
{"attr": node.attr}, IS_INT32 = 2
node.loc) IS_INT64 = 4
diag.notes.append(note)
self.engine.process(diag) state = 0
for elt in attr_value:
if elt.__class__ == float:
state |= IS_FLOAT
elif elt.__class__ == int:
if -2**31 < elt < 2**31-1:
state |= IS_INT32
elif -2**63 < elt < 2**63-1:
state |= IS_INT64
else:
state = -1
break
else:
state = -1
proxy_engine = diagnostic.Engine() if state == IS_FLOAT:
proxy_engine.process = proxy_diagnostic attr_value_type = builtins.TList(builtins.TFloat())
Inferencer(engine=proxy_engine).visit(ast) elif state == IS_INT32:
IntMonomorphizer(engine=proxy_engine).visit(ast) attr_value_type = builtins.TList(builtins.TInt32())
elif state == IS_INT64:
attr_value_type = builtins.TList(builtins.TInt64())
if attr_value_type is None:
# Slow path. We don't know what exactly is the attribute value,
# so we quote it only for the error message that may possibly result.
ast = self.quote(attr_value, object_loc.expanded_from)
def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
"while inferring a type for an attribute '{attr}' of a host object",
{"attr": node.attr},
node.loc)
diag.notes.append(note)
self.engine.process(diag)
proxy_engine = diagnostic.Engine()
proxy_engine.process = proxy_diagnostic
Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast)
attr_value_type = ast.type
if node.attr not in attributes: if node.attr not in attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
attributes[node.attr] = ast.type attributes[node.attr] = attr_value_type
elif attributes[node.attr] != ast.type and not types.is_rpc_function(ast.type): elif attributes[node.attr] != attr_value_type and not types.is_rpc_function(attr_value_type):
# Does this conflict with an earlier guess? # Does this conflict with an earlier guess?
# RPC function types are exempt because RPCs are dynamically typed. # RPC function types are exempt because RPCs are dynamically typed.
printer = types.TypePrinter() printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object has an attribute '{attr}' of type {typea}, which is" "host object has an attribute '{attr}' of type {typea}, which is"
" different from previously inferred type {typeb} for the same attribute", " different from previously inferred type {typeb} for the same attribute",
{"typea": printer.name(ast.type), {"typea": printer.name(attr_value_type),
"typeb": printer.name(attributes[node.attr]), "typeb": printer.name(attributes[node.attr]),
"attr": node.attr}, "attr": node.attr},
object_loc) object_loc)