From 84c520124323f2b8f11e1aba4909bdc1675d59cc Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sun, 31 Oct 2021 17:16:21 +0800 Subject: [PATCH] with parallel/sequential support Behavior of parallel and sequential: Each function call (indirectly, can be inside a sequential block) within a parallel block will update the end variable to the maximum now_mu in the block. Each function call directly inside a parallel block will reset the timeline after execution. A parallel block within a sequential block (or not within any block) will set the timeline to the max now_mu within the block (and the outer max now_mu will also be updated). Implementation: We track the start and end separately. - If there is a start variable, it indicates that we are directly inside a parallel block and we have to reset the timeline after every function call. - If there is a end variable, it indicates that we are (indirectly) inside a parallel block, and we should update the max end value. Note: requires testing, it is difficult to inspect the output IR --- nac3artiq/min_artiq.py | 26 ++- nac3artiq/src/codegen.rs | 196 ++++++++++++++++++ nac3artiq/src/lib.rs | 14 +- nac3core/src/codegen/generator.rs | 8 + nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/stmt.rs | 10 + nac3core/src/typecheck/function_check.rs | 11 + nac3core/src/typecheck/type_inferencer/mod.rs | 91 +++++++- 8 files changed, 347 insertions(+), 13 deletions(-) create mode 100644 nac3artiq/src/codegen.rs diff --git a/nac3artiq/min_artiq.py b/nac3artiq/min_artiq.py index fb743faa..18355dcf 100644 --- a/nac3artiq/min_artiq.py +++ b/nac3artiq/min_artiq.py @@ -5,7 +5,8 @@ from numpy import int32, int64 import nac3artiq -__all__ = ["KernelInvariant", "extern", "kernel", "portable", "ms", "us", "ns", "Core", "TTLOut"] +__all__ = ["KernelInvariant", "extern", "kernel", "portable", "ms", "us", "ns", + "Core", "TTLOut", "parallel", "sequential"] import device_db @@ -15,7 +16,6 @@ nac3 = nac3artiq.NAC3(core_arguments["target"]) allow_module_registration = True registered_ids = set() - def KernelInvariant(t): return t @@ -83,6 +83,14 @@ def rtio_input_timestamp(timeout_mu: int64, channel: int32) -> int64: def rtio_input_data(channel: int32) -> int32: raise NotImplementedError("syscall not simulated") +def at_mu(_): + raise NotImplementedError("at_mu not simulated") + +def now_mu() -> int32: + raise NotImplementedError("now_mu not simulated") + +def delay_mu(_): + raise NotImplementedError("delay_mu not simulated") @kernel class Core: @@ -169,3 +177,17 @@ class TTLOut: self.on() self.core.delay(duration) self.off() + +@portable +class KernelContextManager: + @kernel + def __enter__(self): + pass + + @kernel + def __exit__(self): + pass + +parallel = KernelContextManager() +sequential = KernelContextManager() + diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs new file mode 100644 index 00000000..cb79c549 --- /dev/null +++ b/nac3artiq/src/codegen.rs @@ -0,0 +1,196 @@ +use nac3core::{ + codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator}, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type}, +}; + +use rustpython_parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; + +use inkwell::values::BasicValueEnum; + +use crate::timeline::TimeFns; + +pub struct ArtiqCodeGenerator<'a> { + name: String, + name_counter: u32, + start: Option>>, + end: Option>>, + timeline: &'a (dyn TimeFns + Sync), +} + +impl<'a> ArtiqCodeGenerator<'a> { + pub fn new(name: String, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> { + ArtiqCodeGenerator { + name, + name_counter: 0, + start: None, + end: None, + timeline, + } + } +} + +impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { + fn get_name(&self) -> &str { + &self.name + } + + fn gen_call<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + params: Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> { + let result = gen_call(self, ctx, obj, fun, params); + if let Some(end) = self.end.clone() { + let old_end = self.gen_expr(ctx, &end).unwrap(); + let now = self.timeline.emit_now_mu(ctx); + let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { + let i64 = ctx.ctx.i64_type(); + ctx.module.add_function( + "llvm.smax.i64", + i64.fn_type(&[i64.into(), i64.into()], false), + None, + ) + }); + let max = ctx + .builder + .build_call(smax, &[old_end, now], "smax") + .try_as_basic_value() + .left() + .unwrap(); + let end_store = self.gen_store_target(ctx, &end); + ctx.builder.build_store(end_store, max); + } + if let Some(start) = self.start.clone() { + let start_val = self.gen_expr(ctx, &start).unwrap(); + self.timeline.emit_at_mu(ctx, start_val); + } + result + } + + fn gen_with<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + if let StmtKind::With { items, body, .. } = &stmt.node { + if items.len() == 1 && items[0].optional_vars.is_none() { + let item = &items[0]; + // Behavior of parallel and sequential: + // Each function call (indirectly, can be inside a sequential block) within a parallel + // block will update the end variable to the maximum now_mu in the block. + // Each function call directly inside a parallel block will reset the timeline after + // execution. A parallel block within a sequential block (or not within any block) will + // set the timeline to the max now_mu within the block (and the outer max now_mu will also + // be updated). + // + // Implementation: We track the start and end separately. + // - If there is a start variable, it indicates that we are directly inside a + // parallel block and we have to reset the timeline after every function call. + // - If there is a end variable, it indicates that we are (indirectly) inside a + // parallel block, and we should update the max end value. + if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { + if id == &"parallel".into() { + let old_start = self.start.take(); + let old_end = self.end.take(); + let now = if let Some(old_start) = &old_start { + self.gen_expr(ctx, old_start).unwrap() + } else { + self.timeline.emit_now_mu(ctx) + }; + // Emulate variable allocation, as we need to use the CodeGenContext + // HashMap to store our variable due to lifetime limitation + // Note: we should be able to store variables directly if generic + // associative type is used by limiting the lifetime of CodeGenerator to + // the LLVM Context. + // The name is guaranteed to be unique as users cannot use this as variable + // name. + self.start = old_start.clone().or_else(|| { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { + id: start, + ctx: name_ctx.clone(), + }, + custom: Some(ctx.primitives.int64), + }; + let start = self.gen_store_target(ctx, &start_expr); + ctx.builder.build_store(start, now); + Some(start_expr) + }); + let end = format!("with-{}-end", self.name_counter).into(); + let end_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { + id: end, + ctx: name_ctx.clone(), + }, + custom: Some(ctx.primitives.int64), + }; + let end = self.gen_store_target(ctx, &end_expr); + ctx.builder.build_store(end, now); + self.end = Some(end_expr); + self.name_counter += 1; + let mut exited = false; + for stmt in body.iter() { + if self.gen_stmt(ctx, stmt) { + exited = true; + break; + } + } + // set duration + let end_expr = self.end.take().unwrap(); + let end_val = self.gen_expr(ctx, &end_expr).unwrap(); + + // inside an sequential block + if old_start.is_none() { + self.timeline.emit_at_mu(ctx, end_val); + } + // inside a parallel block, should update the outer max now_mu + if let Some(old_end) = &old_end { + let outer_end_val = self.gen_expr(ctx, old_end).unwrap(); + let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { + let i64 = ctx.ctx.i64_type(); + ctx.module.add_function( + "llvm.smax.i64", + i64.fn_type(&[i64.into(), i64.into()], false), + None, + ) + }); + let max = ctx + .builder + .build_call(smax, &[end_val, outer_end_val], "smax") + .try_as_basic_value() + .left() + .unwrap(); + let outer_end = self.gen_store_target(ctx, old_end); + ctx.builder.build_store(outer_end, max); + } + self.start = old_start; + self.end = old_end; + return exited; + } else if id == &"sequential".into() { + let start = self.start.take(); + for stmt in body.iter() { + if self.gen_stmt(ctx, stmt) { + self.start = start; + return true; + } + } + self.start = start; + return false; + } + } + } + // not parallel/sequential + gen_with(self, ctx, stmt) + } else { + unreachable!() + } + } +} diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 8d9d1e12..ea238189 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -19,18 +19,16 @@ use rustpython_parser::{ use parking_lot::{Mutex, RwLock}; use nac3core::{ - codegen::{ - concrete_type::ConcreteTypeStore, CodeGenTask, DefaultCodeGenerator, WithCall, - WorkerRegistry, - }, + codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, DefinitionId, GenCall, TopLevelContext, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, typecheck::{type_inferencer::PrimitiveStore, typedef::Type}, }; -use crate::symbol_resolver::Resolver; +use crate::{codegen::ArtiqCodeGenerator, symbol_resolver::Resolver}; +mod codegen; mod symbol_resolver; mod timeline; @@ -436,10 +434,14 @@ impl Nac3 { ) .expect("couldn't write module to file"); }))); + let time_fns: &(dyn TimeFns + Sync) = match isa { + Isa::RiscV => &timeline::NOW_PINNING_TIME_FNS, + Isa::CortexA9 => &timeline::EXTERN_TIME_FNS, + }; let thread_names: Vec = (0..4).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| Box::new(DefaultCodeGenerator::new(s.to_string()))) + .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), time_fns))) .collect(); py.allow_threads(|| { diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 53049795..22ee541f 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -125,6 +125,14 @@ pub trait CodeGenerator { gen_if(self, ctx, stmt) } + fn gen_with<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + gen_with(self, ctx, stmt) + } + /// Generate code for a statement /// Return true if the statement must early return fn gen_stmt<'ctx, 'a>( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 17396417..e93dff0a 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -28,9 +28,9 @@ use std::sync::{ use std::thread; pub mod concrete_type; -mod expr; +pub mod expr; +pub mod stmt; mod generator; -mod stmt; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 86722e0d..41779259 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -291,6 +291,15 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( } } +pub fn gen_with<'ctx, 'a, G: CodeGenerator + ?Sized>( + _: &mut G, + _: &mut CodeGenContext<'ctx, 'a>, + _: &Stmt>, +) -> bool { + // TODO: Implement with statement after finishing exceptions + unimplemented!() +} + pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -330,6 +339,7 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( StmtKind::If { .. } => return generator.gen_if(ctx, stmt), StmtKind::While { .. } => return generator.gen_while(ctx, stmt), StmtKind::For { .. } => return generator.gen_for(ctx, stmt), + StmtKind::With { .. } => return generator.gen_with(ctx, stmt), _ => unimplemented!(), }; false diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index dcb20aea..096ce244 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -219,6 +219,17 @@ impl<'a> Inferencer<'a> { self.check_block(orelse, &mut defined_identifiers)?; Ok(false) } + StmtKind::With { items, body, .. } => { + let mut new_defined_identifiers = defined_identifiers.clone(); + for item in items.iter() { + self.check_expr(&item.context_expr, defined_identifiers)?; + if let Some(var) = item.optional_vars.as_ref() { + self.check_pattern(var, &mut new_defined_identifiers)?; + } + } + self.check_block(body, &mut new_defined_identifiers)?; + Ok(false) + } StmtKind::Expr { value } => { self.check_expr(value, defined_identifiers)?; Ok(false) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7f42d454..ce3698a3 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -64,6 +64,10 @@ impl fold::Fold<()> for NaiveFolder { } } +fn report_error(msg: &str, location: Location) -> Result { + Err(format!("{} at {}", msg, location)) +} + impl<'a> fold::Fold<()> for Inferencer<'a> { type TargetU = Option; type Error = String; @@ -165,6 +169,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } fold::fold_stmt(self, node)? } + ast::StmtKind::With { ref items, .. } => { + for item in items.iter() { + if let Some(var) = &item.optional_vars { + self.infer_pattern(var)?; + } + } + fold::fold_stmt(self, node)? + } _ => fold::fold_stmt(self, node)?, }; match &stmt.node { @@ -186,19 +198,92 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} ast::StmtKind::Break | ast::StmtKind::Continue | ast::StmtKind::Pass => {} + ast::StmtKind::With { items, .. } => { + for item in items.iter() { + let ty = item.context_expr.custom.unwrap(); + // if we can simply unify without creating new types... + let mut fast_path = false; + if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { + let fields = fields.borrow(); + fast_path = true; + if let Some(enter) = fields.get(&"__enter__".into()).cloned() { + if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) { + let signature = signature.borrow(); + if !signature.args.is_empty() { + return report_error( + "__enter__ method should take no argument other than self", + stmt.location + ) + } + if let Some(var) = &item.optional_vars { + if signature.vars.is_empty() { + self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?; + } else { + fast_path = false; + } + } + } else { + fast_path = false; + } + } else { + return report_error( + "__enter__ method is required for context manager", + stmt.location + ); + } + if let Some(exit) = fields.get(&"__exit__".into()).cloned() { + if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) { + let signature = signature.borrow(); + if !signature.args.is_empty() { + return report_error( + "__exit__ method should take no argument other than self", + stmt.location + ) + } + } else { + fast_path = false; + } + } else { + return report_error( + "__exit__ method is required for context manager", + stmt.location + ); + } + } + if !fast_path { + let enter = TypeEnum::TFunc(RefCell::new(FunSignature { + args: vec![], + ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()), + vars: Default::default() + })); + let enter = self.unifier.add_ty(enter); + let exit = TypeEnum::TFunc(RefCell::new(FunSignature { + args: vec![], + ret: self.unifier.get_fresh_var().0, + vars: Default::default() + })); + let exit = self.unifier.add_ty(exit); + let mut fields = HashMap::new(); + fields.insert("__enter__".into(), enter); + fields.insert("__exit__".into(), exit); + let record = self.unifier.add_record(fields); + self.unify(ty, record, &stmt.location)?; + } + } + } ast::StmtKind::Return { value } => match (value, self.function_data.return_type) { (Some(v), Some(v1)) => { self.unify(v.custom.unwrap(), v1, &v.location)?; } (Some(_), None) => { - return Err("Unexpected return value".to_string()); + return report_error("Unexpected return value", stmt.location); } (None, Some(_)) => { - return Err("Expected return value".to_string()); + return report_error("Expected return value", stmt.location); } (None, None) => {} }, - _ => return Err("Unsupported statement type".to_string()), + _ => return report_error("Unsupported statement type", stmt.location), }; Ok(stmt) }