From 1948aa38b6166b8ebb73a8c795f18b92277082fe Mon Sep 17 00:00:00 2001 From: occheung Date: Fri, 7 Feb 2025 12:22:37 +0800 Subject: [PATCH] nac3artiq: register special python instances --- nac3artiq/demo/min_artiq.py | 14 ++- nac3artiq/src/codegen.rs | 240 ++++++++++++++++++++---------------- nac3artiq/src/lib.rs | 31 +++++ 3 files changed, 175 insertions(+), 110 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index fef018b2..cfceed95 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -16,7 +16,7 @@ __all__ = [ "rpc", "ms", "us", "ns", "print_int32", "print_int64", "Core", "TTLOut", - "parallel", "sequential" + "parallel", "legacy_parallel", "sequential" ] @@ -96,6 +96,7 @@ allow_registration = True # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. registered_functions = set() registered_classes = set() +registered_special_ids = dict() def register_function(fun): assert allow_registration @@ -105,6 +106,10 @@ def register_class(cls): assert allow_registration registered_classes.add(cls) +def register_special(name, instance): + assert allow_registration + registered_special_ids[name] = id(instance) + def extern(function): """Decorates a function declaration defined by the core device runtime.""" @@ -245,7 +250,7 @@ class Core: embedding = EmbeddingMap() if allow_registration: - compiler.analyze(registered_functions, registered_classes, set()) + compiler.analyze(registered_functions, registered_classes, registered_special_ids, set()) allow_registration = False if hasattr(method, "__self__"): @@ -336,4 +341,9 @@ class UnwrapNoneError(Exception): artiq_builtin = True parallel = KernelContextManager() +legacy_parallel = KernelContextManager() sequential = KernelContextManager() + +register_special("parallel", parallel) +register_special("legacy_parallel", legacy_parallel) +register_special("sequential", sequential) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 572acccf..61c77246 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -12,7 +12,7 @@ use pyo3::{ PyObject, PyResult, Python, }; -use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use super::{symbol_resolver::InnerResolver, timeline::TimeFns, SpecialPythonId}; use nac3core::{ codegen::{ expr::{destructure_range, gen_call}, @@ -83,6 +83,9 @@ pub struct ArtiqCodeGenerator<'a> { /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// statement, which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, + + /// Specially treated python IDs to identify `with parallel` and `with sequential` blocks. + special_ids: SpecialPythonId, } impl<'a> ArtiqCodeGenerator<'a> { @@ -90,6 +93,7 @@ impl<'a> ArtiqCodeGenerator<'a> { name: String, size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { @@ -100,6 +104,7 @@ impl<'a> ArtiqCodeGenerator<'a> { end: None, timeline, parallel_mode: ParallelMode::None, + special_ids, } } @@ -109,9 +114,10 @@ impl<'a> ArtiqCodeGenerator<'a> { ctx: &Context, target_machine: &TargetMachine, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); - Self::new(name, llvm_usize, timeline) + Self::new(name, llvm_usize, timeline, special_ids) } /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the @@ -257,122 +263,140 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> { // - If there is a end variable, it indicates that we are (indirectly) inside a // parallel block, and we should update the max end value. if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { - if id == &"parallel".into() || id == &"legacy_parallel".into() { - let old_start = self.start.take(); - let old_end = self.end.take(); - let old_parallel_mode = self.parallel_mode; + let resolver = ctx.resolver.clone(); + if let Some(static_value) = + if let Some((_ptr, static_value, _counter)) = ctx.var_assignment.get(id) { + static_value.clone() + } else if let Some(ValueEnum::Static(val)) = + resolver.get_symbol_value(*id, ctx, self) + { + Some(val) + } else { + None + } + { + let python_id = static_value.get_unique_identifier(); + if python_id == self.special_ids.parallel + || python_id == self.special_ids.legacy_parallel + { + let old_start = self.start.take(); + let old_end = self.end.take(); + let old_parallel_mode = self.parallel_mode; - let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + let now = if let Some(old_start) = &old_start { + self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + ctx, + self, + old_start.custom.unwrap(), + )? + } else { + self.timeline.emit_now_mu(ctx) + }; + + // Emulate variable allocation, as we need to use the CodeGenContext + // HashMap to store our variable due to lifetime limitation + // Note: we should be able to store variables directly if generic + // associative type is used by limiting the lifetime of CodeGenerator to + // the LLVM Context. + // The name is guaranteed to be unique as users cannot use this as variable + // name. + self.start = old_start.clone().map_or_else( + || { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: start, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let start = self + .gen_store_target(ctx, &start_expr, Some("start.addr"))? + .unwrap(); + ctx.builder.build_store(start, now).unwrap(); + Ok(Some(start_expr)) as Result<_, String> + }, + |v| Ok(Some(v)), + )?; + let end = format!("with-{}-end", self.name_counter).into(); + let end_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: end, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); + ctx.builder.build_store(end, now).unwrap(); + self.end = Some(end_expr); + self.name_counter += 1; + self.parallel_mode = if python_id == self.special_ids.parallel { + ParallelMode::Deep + } else if python_id == self.special_ids.legacy_parallel { + ParallelMode::Legacy + } else { + unreachable!() + }; + + self.gen_block(ctx, body.iter())?; + + let current = ctx.builder.get_insert_block().unwrap(); + + // if the current block is terminated, move before the terminator + // we want to set the timeline before reaching the terminator + // TODO: This may be unsound if there are multiple exit paths in the + // block... e.g. + // if ...: + // return + // Perhaps we can fix this by using actual with block? + let reset_position = if let Some(terminator) = current.get_terminator() { + ctx.builder.position_before(&terminator); + true + } else { + false + }; + + // set duration + let end_expr = self.end.take().unwrap(); + let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( ctx, self, - old_start.custom.unwrap(), - )? - } else { - self.timeline.emit_now_mu(ctx) - }; + end_expr.custom.unwrap(), + )?; - // Emulate variable allocation, as we need to use the CodeGenContext - // HashMap to store our variable due to lifetime limitation - // Note: we should be able to store variables directly if generic - // associative type is used by limiting the lifetime of CodeGenerator to - // the LLVM Context. - // The name is guaranteed to be unique as users cannot use this as variable - // name. - self.start = old_start.clone().map_or_else( - || { - let start = format!("with-{}-start", self.name_counter).into(); - let start_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: start, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let start = self - .gen_store_target(ctx, &start_expr, Some("start.addr"))? - .unwrap(); - ctx.builder.build_store(start, now).unwrap(); - Ok(Some(start_expr)) as Result<_, String> - }, - |v| Ok(Some(v)), - )?; - let end = format!("with-{}-end", self.name_counter).into(); - let end_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: end, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); - ctx.builder.build_store(end, now).unwrap(); - self.end = Some(end_expr); - self.name_counter += 1; - self.parallel_mode = match id.to_string().as_str() { - "parallel" => ParallelMode::Deep, - "legacy_parallel" => ParallelMode::Legacy, - _ => unreachable!(), - }; + // inside a sequential block + if old_start.is_none() { + self.timeline.emit_at_mu(ctx, end_val); + } - self.gen_block(ctx, body.iter())?; + // inside a parallel block, should update the outer max now_mu + self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - let current = ctx.builder.get_insert_block().unwrap(); + self.parallel_mode = old_parallel_mode; + self.end = old_end; + self.start = old_start; - // if the current block is terminated, move before the terminator - // we want to set the timeline before reaching the terminator - // TODO: This may be unsound if there are multiple exit paths in the - // block... e.g. - // if ...: - // return - // Perhaps we can fix this by using actual with block? - let reset_position = if let Some(terminator) = current.get_terminator() { - ctx.builder.position_before(&terminator); - true - } else { - false - }; + if reset_position { + ctx.builder.position_at_end(current); + } - // set duration - let end_expr = self.end.take().unwrap(); - let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( - ctx, - self, - end_expr.custom.unwrap(), - )?; + return Ok(()); + } else if python_id == self.special_ids.sequential { + // For deep parallel, temporarily take away start to avoid function calls in + // the block from resetting the timeline. + // This does not affect legacy parallel, as the timeline will be reset after + // this block finishes execution. + let start = self.start.take(); + self.gen_block(ctx, body.iter())?; + self.start = start; - // inside a sequential block - if old_start.is_none() { - self.timeline.emit_at_mu(ctx, end_val); + // Reset the timeline when we are exiting the sequential block + // Legacy parallel does not need this, since it will be reset after codegen + // for this statement is completed + if self.parallel_mode == ParallelMode::Deep { + self.timeline_reset_start(ctx)?; + } + + return Ok(()); } - - // inside a parallel block, should update the outer max now_mu - self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - - self.parallel_mode = old_parallel_mode; - self.end = old_end; - self.start = old_start; - - if reset_position { - ctx.builder.position_at_end(current); - } - - return Ok(()); - } else if id == &"sequential".into() { - // For deep parallel, temporarily take away start to avoid function calls in - // the block from resetting the timeline. - // This does not affect legacy parallel, as the timeline will be reset after - // this block finishes execution. - let start = self.start.take(); - self.gen_block(ctx, body.iter())?; - self.start = start; - - // Reset the timeline when we are exiting the sequential block - // Legacy parallel does not need this, since it will be reset after codegen - // for this statement is completed - if self.parallel_mode == ParallelMode::Deep { - self.timeline_reset_start(ctx)?; - } - - return Ok(()); } } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ba6c4fae..d4136a07 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -162,6 +162,13 @@ pub struct PrimitivePythonId { module: u64, } +#[derive(Clone, Default)] +pub struct SpecialPythonId { + parallel: u64, + legacy_parallel: u64, + sequential: u64, +} + type TopLevelComponent = (Stmt, String, PyObject); // TopLevelComposer is unsendable as it holds the unification table, which is @@ -179,6 +186,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, + special_ids: SpecialPythonId, /// LLVM-related options for code generation. llvm_options: CodeGenLLVMOptions, } @@ -797,6 +805,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), )) }) .collect(); @@ -813,6 +822,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); @@ -1192,6 +1202,7 @@ impl Nac3 { string_store: Arc::new(string_store.into()), exception_ids: Arc::default(), deferred_eval_store: DeferredEvaluationStore::new(), + special_ids: Default::default(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, target: isa.get_llvm_target_options(), @@ -1203,6 +1214,7 @@ impl Nac3 { &mut self, functions: &PySet, classes: &PySet, + special_ids: &PyDict, content_modules: &PySet, ) -> PyResult<()> { let (modules, class_ids) = @@ -1236,6 +1248,25 @@ impl Nac3 { for module in modules.into_values() { self.register_module(&module, &class_ids)?; } + + self.special_ids = SpecialPythonId { + parallel: special_ids.get_item("parallel").ok().flatten().unwrap().extract().unwrap(), + legacy_parallel: special_ids + .get_item("legacy_parallel") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + sequential: special_ids + .get_item("sequential") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + }; + Ok(()) }