diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 9e31dbb14..c5dfe3c08 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -323,33 +323,67 @@ class StitchingInferencer(Inferencer): else: attributes = object_type.attributes - ast = self.quote(attr_value, object_loc.expanded_from) + attr_value_type = None - def proxy_diagnostic(diag): - note = diagnostic.Diagnostic("note", - "while inferring a type for an attribute '{attr}' of a host object", - {"attr": node.attr}, - node.loc) - diag.notes.append(note) + if isinstance(attr_value, list): + # Fast path for lists of scalars. + IS_FLOAT = 1 + IS_INT32 = 2 + IS_INT64 = 4 - self.engine.process(diag) + state = 0 + for elt in attr_value: + if elt.__class__ == float: + state |= IS_FLOAT + elif elt.__class__ == int: + if -2**31 < elt < 2**31-1: + state |= IS_INT32 + elif -2**63 < elt < 2**63-1: + state |= IS_INT64 + else: + state = -1 + break + else: + state = -1 - proxy_engine = diagnostic.Engine() - proxy_engine.process = proxy_diagnostic - Inferencer(engine=proxy_engine).visit(ast) - IntMonomorphizer(engine=proxy_engine).visit(ast) + if state == IS_FLOAT: + attr_value_type = builtins.TList(builtins.TFloat()) + elif state == IS_INT32: + attr_value_type = builtins.TList(builtins.TInt32()) + elif state == IS_INT64: + attr_value_type = builtins.TList(builtins.TInt64()) + + if attr_value_type is None: + # Slow path. We don't know what exactly is the attribute value, + # so we quote it only for the error message that may possibly result. + ast = self.quote(attr_value, object_loc.expanded_from) + + def proxy_diagnostic(diag): + note = diagnostic.Diagnostic("note", + "while inferring a type for an attribute '{attr}' of a host object", + {"attr": node.attr}, + node.loc) + diag.notes.append(note) + + self.engine.process(diag) + + proxy_engine = diagnostic.Engine() + proxy_engine.process = proxy_diagnostic + Inferencer(engine=proxy_engine).visit(ast) + IntMonomorphizer(engine=proxy_engine).visit(ast) + attr_value_type = ast.type if node.attr not in attributes: # We just figured out what the type should be. Add it. - attributes[node.attr] = ast.type - elif attributes[node.attr] != ast.type and not types.is_rpc_function(ast.type): + attributes[node.attr] = attr_value_type + elif attributes[node.attr] != attr_value_type and not types.is_rpc_function(attr_value_type): # Does this conflict with an earlier guess? # RPC function types are exempt because RPCs are dynamically typed. printer = types.TypePrinter() diag = diagnostic.Diagnostic("error", "host object has an attribute '{attr}' of type {typea}, which is" " different from previously inferred type {typeb} for the same attribute", - {"typea": printer.name(ast.type), + {"typea": printer.name(attr_value_type), "typeb": printer.name(attributes[node.attr]), "attr": node.attr}, object_loc)