forked from M-Labs/artiq
compiler.embedding: add fast path for inference for lists of scalars in a host object.
This commit is contained in:
parent
237e983770
commit
3c9b53b07b
|
@ -323,6 +323,39 @@ class StitchingInferencer(Inferencer):
|
||||||
else:
|
else:
|
||||||
attributes = object_type.attributes
|
attributes = object_type.attributes
|
||||||
|
|
||||||
|
attr_value_type = None
|
||||||
|
|
||||||
|
if isinstance(attr_value, list):
|
||||||
|
# Fast path for lists of scalars.
|
||||||
|
IS_FLOAT = 1
|
||||||
|
IS_INT32 = 2
|
||||||
|
IS_INT64 = 4
|
||||||
|
|
||||||
|
state = 0
|
||||||
|
for elt in attr_value:
|
||||||
|
if elt.__class__ == float:
|
||||||
|
state |= IS_FLOAT
|
||||||
|
elif elt.__class__ == int:
|
||||||
|
if -2**31 < elt < 2**31-1:
|
||||||
|
state |= IS_INT32
|
||||||
|
elif -2**63 < elt < 2**63-1:
|
||||||
|
state |= IS_INT64
|
||||||
|
else:
|
||||||
|
state = -1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
state = -1
|
||||||
|
|
||||||
|
if state == IS_FLOAT:
|
||||||
|
attr_value_type = builtins.TList(builtins.TFloat())
|
||||||
|
elif state == IS_INT32:
|
||||||
|
attr_value_type = builtins.TList(builtins.TInt32())
|
||||||
|
elif state == IS_INT64:
|
||||||
|
attr_value_type = builtins.TList(builtins.TInt64())
|
||||||
|
|
||||||
|
if attr_value_type is None:
|
||||||
|
# Slow path. We don't know what exactly is the attribute value,
|
||||||
|
# so we quote it only for the error message that may possibly result.
|
||||||
ast = self.quote(attr_value, object_loc.expanded_from)
|
ast = self.quote(attr_value, object_loc.expanded_from)
|
||||||
|
|
||||||
def proxy_diagnostic(diag):
|
def proxy_diagnostic(diag):
|
||||||
|
@ -338,18 +371,19 @@ class StitchingInferencer(Inferencer):
|
||||||
proxy_engine.process = proxy_diagnostic
|
proxy_engine.process = proxy_diagnostic
|
||||||
Inferencer(engine=proxy_engine).visit(ast)
|
Inferencer(engine=proxy_engine).visit(ast)
|
||||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||||
|
attr_value_type = ast.type
|
||||||
|
|
||||||
if node.attr not in attributes:
|
if node.attr not in attributes:
|
||||||
# We just figured out what the type should be. Add it.
|
# We just figured out what the type should be. Add it.
|
||||||
attributes[node.attr] = ast.type
|
attributes[node.attr] = attr_value_type
|
||||||
elif attributes[node.attr] != ast.type and not types.is_rpc_function(ast.type):
|
elif attributes[node.attr] != attr_value_type and not types.is_rpc_function(attr_value_type):
|
||||||
# Does this conflict with an earlier guess?
|
# Does this conflict with an earlier guess?
|
||||||
# RPC function types are exempt because RPCs are dynamically typed.
|
# RPC function types are exempt because RPCs are dynamically typed.
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"host object has an attribute '{attr}' of type {typea}, which is"
|
"host object has an attribute '{attr}' of type {typea}, which is"
|
||||||
" different from previously inferred type {typeb} for the same attribute",
|
" different from previously inferred type {typeb} for the same attribute",
|
||||||
{"typea": printer.name(ast.type),
|
{"typea": printer.name(attr_value_type),
|
||||||
"typeb": printer.name(attributes[node.attr]),
|
"typeb": printer.name(attributes[node.attr]),
|
||||||
"attr": node.attr},
|
"attr": node.attr},
|
||||||
object_loc)
|
object_loc)
|
||||||
|
|
Loading…
Reference in New Issue