Compare commits

..

1 Commits

Author SHA1 Message Date
1948aa38b6 nac3artiq: register special python instances 2025-02-07 12:24:56 +08:00
3 changed files with 175 additions and 110 deletions

View File

@ -16,7 +16,7 @@ __all__ = [
"rpc", "ms", "us", "ns",
"print_int32", "print_int64",
"Core", "TTLOut",
"parallel", "sequential"
"parallel", "legacy_parallel", "sequential"
]
@ -96,6 +96,7 @@ allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_functions = set()
registered_classes = set()
registered_special_ids = dict()
def register_function(fun):
assert allow_registration
@ -105,6 +106,10 @@ def register_class(cls):
assert allow_registration
registered_classes.add(cls)
def register_special(name, instance):
assert allow_registration
registered_special_ids[name] = id(instance)
def extern(function):
"""Decorates a function declaration defined by the core device runtime."""
@ -245,7 +250,7 @@ class Core:
embedding = EmbeddingMap()
if allow_registration:
compiler.analyze(registered_functions, registered_classes, set())
compiler.analyze(registered_functions, registered_classes, registered_special_ids, set())
allow_registration = False
if hasattr(method, "__self__"):
@ -336,4 +341,9 @@ class UnwrapNoneError(Exception):
artiq_builtin = True
parallel = KernelContextManager()
legacy_parallel = KernelContextManager()
sequential = KernelContextManager()
register_special("parallel", parallel)
register_special("legacy_parallel", legacy_parallel)
register_special("sequential", sequential)

View File

@ -12,7 +12,7 @@ use pyo3::{
PyObject, PyResult, Python,
};
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use super::{symbol_resolver::InnerResolver, timeline::TimeFns, SpecialPythonId};
use nac3core::{
codegen::{
expr::{destructure_range, gen_call},
@ -83,6 +83,9 @@ pub struct ArtiqCodeGenerator<'a> {
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
/// statement, which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode,
/// Specially treated python IDs to identify `with parallel` and `with sequential` blocks.
special_ids: SpecialPythonId,
}
impl<'a> ArtiqCodeGenerator<'a> {
@ -90,6 +93,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
name: String,
size_t: IntType<'_>,
timeline: &'a (dyn TimeFns + Sync),
special_ids: SpecialPythonId,
) -> ArtiqCodeGenerator<'a> {
assert!(matches!(size_t.get_bit_width(), 32 | 64));
ArtiqCodeGenerator {
@ -100,6 +104,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
end: None,
timeline,
parallel_mode: ParallelMode::None,
special_ids,
}
}
@ -109,9 +114,10 @@ impl<'a> ArtiqCodeGenerator<'a> {
ctx: &Context,
target_machine: &TargetMachine,
timeline: &'a (dyn TimeFns + Sync),
special_ids: SpecialPythonId,
) -> ArtiqCodeGenerator<'a> {
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
Self::new(name, llvm_usize, timeline)
Self::new(name, llvm_usize, timeline, special_ids)
}
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
@ -257,7 +263,22 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
// - If there is a end variable, it indicates that we are (indirectly) inside a
// parallel block, and we should update the max end value.
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
if id == &"parallel".into() || id == &"legacy_parallel".into() {
let resolver = ctx.resolver.clone();
if let Some(static_value) =
if let Some((_ptr, static_value, _counter)) = ctx.var_assignment.get(id) {
static_value.clone()
} else if let Some(ValueEnum::Static(val)) =
resolver.get_symbol_value(*id, ctx, self)
{
Some(val)
} else {
None
}
{
let python_id = static_value.get_unique_identifier();
if python_id == self.special_ids.parallel
|| python_id == self.special_ids.legacy_parallel
{
let old_start = self.start.take();
let old_end = self.end.take();
let old_parallel_mode = self.parallel_mode;
@ -307,10 +328,12 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
ctx.builder.build_store(end, now).unwrap();
self.end = Some(end_expr);
self.name_counter += 1;
self.parallel_mode = match id.to_string().as_str() {
"parallel" => ParallelMode::Deep,
"legacy_parallel" => ParallelMode::Legacy,
_ => unreachable!(),
self.parallel_mode = if python_id == self.special_ids.parallel {
ParallelMode::Deep
} else if python_id == self.special_ids.legacy_parallel {
ParallelMode::Legacy
} else {
unreachable!()
};
self.gen_block(ctx, body.iter())?;
@ -356,7 +379,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
}
return Ok(());
} else if id == &"sequential".into() {
} else if python_id == self.special_ids.sequential {
// For deep parallel, temporarily take away start to avoid function calls in
// the block from resetting the timeline.
// This does not affect legacy parallel, as the timeline will be reset after
@ -376,6 +399,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
}
}
}
}
// not parallel/sequential
gen_with(self, ctx, stmt)

View File

@ -162,6 +162,13 @@ pub struct PrimitivePythonId {
module: u64,
}
#[derive(Clone, Default)]
pub struct SpecialPythonId {
parallel: u64,
legacy_parallel: u64,
sequential: u64,
}
type TopLevelComponent = (Stmt, String, PyObject);
// TopLevelComposer is unsendable as it holds the unification table, which is
@ -179,6 +186,7 @@ struct Nac3 {
string_store: Arc<RwLock<HashMap<String, i32>>>,
exception_ids: Arc<RwLock<HashMap<usize, usize>>>,
deferred_eval_store: DeferredEvaluationStore,
special_ids: SpecialPythonId,
/// LLVM-related options for code generation.
llvm_options: CodeGenLLVMOptions,
}
@ -797,6 +805,7 @@ impl Nac3 {
&context,
&self.get_llvm_target_machine(),
self.time_fns,
self.special_ids.clone(),
))
})
.collect();
@ -813,6 +822,7 @@ impl Nac3 {
&context,
&self.get_llvm_target_machine(),
self.time_fns,
self.special_ids.clone(),
);
let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap();
@ -1192,6 +1202,7 @@ impl Nac3 {
string_store: Arc::new(string_store.into()),
exception_ids: Arc::default(),
deferred_eval_store: DeferredEvaluationStore::new(),
special_ids: Default::default(),
llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: isa.get_llvm_target_options(),
@ -1203,6 +1214,7 @@ impl Nac3 {
&mut self,
functions: &PySet,
classes: &PySet,
special_ids: &PyDict,
content_modules: &PySet,
) -> PyResult<()> {
let (modules, class_ids) =
@ -1236,6 +1248,25 @@ impl Nac3 {
for module in modules.into_values() {
self.register_module(&module, &class_ids)?;
}
self.special_ids = SpecialPythonId {
parallel: special_ids.get_item("parallel").ok().flatten().unwrap().extract().unwrap(),
legacy_parallel: special_ids
.get_item("legacy_parallel")
.ok()
.flatten()
.unwrap()
.extract()
.unwrap(),
sequential: special_ids
.get_item("sequential")
.ok()
.flatten()
.unwrap()
.extract()
.unwrap(),
};
Ok(())
}