From 81d2b37b571f373267a85ddb417f2d0278da36da Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Tue, 26 Sep 2023 23:31:21 +0100 Subject: [PATCH] compiler: Fix crash on multiple types with the same name The original fix in 21574bdfa91223f2aedf6f7d3a568f450f750516 was incomplete, as it only addressed the TInstance types, but not their linked (typ.constructor) TConstructor instances. This would (potentially among other issues) cause assertion errors in llvm_ir_generator due to the wrong associated globals being referenced; see added test case for an example that previously caused such a crash. Also modified the name collision detection from O(len(type_map)) (so quadratic overall in the number of custom types) to cache names in sets for O(1) lookup. --- artiq/compiler/embedding.py | 46 ++++++++++++------- artiq/test/lit/embedding/class_same_name.py | 51 +++++++++++++++++++++ 2 files changed, 81 insertions(+), 16 deletions(-) create mode 100644 artiq/test/lit/embedding/class_same_name.py diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index eb3cc3a8f..ad23e01cc 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -45,7 +45,14 @@ class EmbeddingMap: self.object_forward_map = {} self.object_reverse_map = {} self.module_map = {} + + # type_map connects the host Python `type` to the pair of associated + # `(TInstance, TConstructor)`s. The `used_…_names` sets cache the + # respective `.name`s for O(1) collision avoidance. self.type_map = {} + self.used_instance_type_names = set() + self.used_constructor_type_names = set() + self.function_map = {} # Modules @@ -60,16 +67,6 @@ class EmbeddingMap: # Types def store_type(self, host_type, instance_type, constructor_type): - self._rename_type(instance_type) - self.type_map[host_type] = (instance_type, constructor_type) - - def retrieve_type(self, host_type): - return self.type_map[host_type] - - def has_type(self, host_type): - return host_type in self.type_map - - def _rename_type(self, new_instance_type): # Generally, user-defined types that have exact same name (which is to say, classes # defined inside functions) do not pose a problem to the compiler. The two places which # cannot handle this are: @@ -78,12 +75,29 @@ class EmbeddingMap: # Since handling #2 requires renaming on ARTIQ side anyway, it's more straightforward # to do it once when embedding (since non-embedded code cannot define classes in # functions). Also, easier to debug. - n = 0 - for host_type in self.type_map: - instance_type, constructor_type = self.type_map[host_type] - if instance_type.name == new_instance_type.name: - n += 1 - new_instance_type.name = "{}.{}".format(new_instance_type.name, n) + suffix = 0 + new_instance_name = instance_type.name + new_constructor_name = constructor_type.name + while True: + if (new_instance_name not in self.used_instance_type_names + and new_constructor_name not in self.used_constructor_type_names): + break + suffix += 1 + new_instance_name = f"{instance_type.name}.{suffix}" + new_constructor_name = f"{constructor_type.name}.{suffix}" + + self.used_instance_type_names.add(new_instance_name) + instance_type.name = new_instance_name + self.used_constructor_type_names.add(new_constructor_name) + constructor_type.name = new_constructor_name + + self.type_map[host_type] = (instance_type, constructor_type) + + def retrieve_type(self, host_type): + return self.type_map[host_type] + + def has_type(self, host_type): + return host_type in self.type_map def attribute_count(self): count = 0 diff --git a/artiq/test/lit/embedding/class_same_name.py b/artiq/test/lit/embedding/class_same_name.py new file mode 100644 index 000000000..46cf7c16d --- /dev/null +++ b/artiq/test/lit/embedding/class_same_name.py @@ -0,0 +1,51 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s + +from artiq.language.core import * + + +class InnerA: + def __init__(self, val): + self.val = val + + @kernel + def run_once(self): + return self.val + + +class InnerB: + def __init__(self, val): + self.val = val + + @kernel + def run_once(self): + return self.val + + +def make_runner(InnerCls, val): + class Runner: + def __init__(self): + self.inner = InnerCls(val) + + @kernel + def run_once(self): + return self.inner.run_once() + + return Runner() + + +class Parent: + def __init__(self): + self.a = make_runner(InnerA, 1) + self.b = make_runner(InnerB, 42.0) + + @kernel + def run_once(self): + return self.a.run_once() + self.b.run_once() + + +parent = Parent() + + +@kernel +def entrypoint(): + parent.run_once()