1
0
forked from M-Labs/nac3

Cleaned up implementation

This commit is contained in:
ram 2025-01-28 06:14:07 +00:00
parent 83fa1db6b8
commit 9f7dbecae2
2 changed files with 55 additions and 29 deletions

View File

@ -4,6 +4,7 @@ from types import SimpleNamespace
from numpy import int32, int64 from numpy import int32, int64
from typing import Generic, TypeVar from typing import Generic, TypeVar
from math import floor, ceil from math import floor, ceil
import builtins
import nac3artiq import nac3artiq
from embedding_map import EmbeddingMap from embedding_map import EmbeddingMap
@ -17,7 +18,8 @@ __all__ = [
"rpc", "ms", "us", "ns", "rpc", "ms", "us", "ns",
"print_int32", "print_int64", "print_int32", "print_int64",
"Core", "TTLOut", "Core", "TTLOut",
"parallel", "sequential" "parallel", "sequential",
"StringWrapper"
] ]
@ -91,6 +93,7 @@ artiq_builtins = {
"virtual": virtual, "virtual": virtual,
"_ConstGenericMarker": _ConstGenericMarker, "_ConstGenericMarker": _ConstGenericMarker,
"Option": Option, "Option": Option,
"str": builtins.str,
} }
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins) compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
allow_registration = True allow_registration = True
@ -204,6 +207,8 @@ class Core:
global allow_registration global allow_registration
embedding = EmbeddingMap() embedding = EmbeddingMap()
for value, str_id in sorted(string_store.items(), key=lambda x: x[1]):
embedding.string_map[value] = str_id
if allow_registration: if allow_registration:
compiler.analyze(registered_functions, registered_classes, set()) compiler.analyze(registered_functions, registered_classes, set())
@ -291,6 +296,36 @@ class KernelContextManager:
def __exit__(self): def __exit__(self):
pass pass
@nac3
class StringWrapper:
"""Wrapper for Python strings in NAC3"""
artiq_builtin = True
_value: str
_id: int
def __init__(self, value: str):
global next_string_id
self._value = value
if value not in string_store:
string_store[value] = next_string_id
next_string_id += 1
self._id = string_store[value]
def __str__(self):
return self._value
def get_identifier(self) -> int:
return self._id
string_store = {}
NAC3_INTERNAL_STRINGS = {
"0:artiq.coredevice.exceptions.RTIOUnderflow": 0,
"": 1,
}
for s, id in NAC3_INTERNAL_STRINGS.items():
string_store[s] = id
next_string_id = max(NAC3_INTERNAL_STRINGS.values()) + 1
@nac3 @nac3
class UnwrapNoneError(Exception): class UnwrapNoneError(Exception):
"""raised when unwrapping a none value""" """raised when unwrapping a none value"""

View File

@ -379,16 +379,16 @@ impl Nac3 {
link_fn: &dyn Fn(&Module) -> PyResult<T>, link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> { ) -> PyResult<T> {
Python::with_gil(|_py| { Python::with_gil(|_py| {
let mut string_store = self.string_store.write(); let string_map_py = embedding_map.getattr("string_map")?;
for arg in &args { let reverse_map_py = embedding_map.getattr("string_reverse_map")?;
if let Ok(s) = arg.extract::<String>() {
if !string_store.contains_key(&s) { let string_store = self.string_store.read();
let next_id = i32::try_from(string_store.len()).unwrap(); for (s, key) in string_store.iter() {
string_store.insert(s.clone(), next_id); string_map_py.set_item(s, key)?;
} reverse_map_py.set_item(key, s)?;
}
} }
}); Ok::<_, PyErr>(())
})?;
let size_t = self.isa.get_size_type(); let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
@ -565,10 +565,16 @@ impl Nac3 {
name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?); name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?);
let mut arg_names = vec![]; let mut arg_names = vec![];
for (i, arg) in args.into_iter().enumerate() { for (i, arg) in args.into_iter().enumerate() {
let name = format!("tmp{i}"); if let Ok(st) = arg.extract::<String>() {
module.add(&name, arg)?; let literal = format!("{st:?}");
name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?); arg_names.push(literal);
arg_names.push(name); } else {
let tmp_name = format!("tmp{i}");
module.add(&tmp_name, arg)?;
let pyid_val: u64 = id_fun.call1((arg,))?.extract()?;
name_to_pyid.insert(tmp_name.clone().into(), pyid_val);
arg_names.push(tmp_name);
}
} }
let synthesized = if method_name.is_empty() { let synthesized = if method_name.is_empty() {
format!("def __modinit__():\n base({})", arg_names.join(", ")) format!("def __modinit__():\n base({})", arg_names.join(", "))
@ -834,20 +840,6 @@ impl Nac3 {
panic!("Failed to run optimization for module `main`: {}", err.to_string()); panic!("Failed to run optimization for module `main`: {}", err.to_string());
} }
Python::with_gil(|py| {
let string_store = self.string_store.read();
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
string_store_vec.sort_by(|(_s1, key1), (_s2, key2)| key1.cmp(key2));
for (s, key) in string_store_vec {
let embed_key: i32 = helper.store_str.call1(py, (s,)).unwrap().extract(py).unwrap();
assert_eq!(
embed_key, *key,
"string {s} is out of sync between embedding map (key={embed_key}) and \
the internal string store (key={key})"
);
}
});
link_fn(&main) link_fn(&main)
} }
@ -1107,7 +1099,6 @@ impl Nac3 {
fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap(); fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
let mut string_store: HashMap<String, i32> = HashMap::default(); let mut string_store: HashMap<String, i32> = HashMap::default();
string_store.insert(String::new(), 0);
// Keep this list of exceptions in sync with `EXCEPTION_ID_LOOKUP` in `artiq::firmware::ksupport::eh_artiq` // Keep this list of exceptions in sync with `EXCEPTION_ID_LOOKUP` in `artiq::firmware::ksupport::eh_artiq`
// The exceptions declared here must be defined in `artiq.coredevice.exceptions` // The exceptions declared here must be defined in `artiq.coredevice.exceptions`