2021-10-31 17:16:21 +08:00
|
|
|
use nac3core::{
|
2022-02-12 21:17:37 +08:00
|
|
|
codegen::{
|
|
|
|
expr::gen_call,
|
2024-02-22 01:47:26 +08:00
|
|
|
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
|
2023-10-25 16:16:20 +08:00
|
|
|
stmt::{gen_block, gen_with},
|
2022-02-12 21:17:37 +08:00
|
|
|
CodeGenContext, CodeGenerator,
|
|
|
|
},
|
2021-12-27 22:55:51 +08:00
|
|
|
symbol_resolver::ValueEnum,
|
2024-06-12 15:01:01 +08:00
|
|
|
toplevel::{helper::PrimDef, DefinitionId, GenCall},
|
2024-06-12 14:45:03 +08:00
|
|
|
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
2021-10-31 17:16:21 +08:00
|
|
|
};
|
|
|
|
|
2021-11-03 17:11:00 +08:00
|
|
|
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
2021-10-31 17:16:21 +08:00
|
|
|
|
2022-02-12 21:17:37 +08:00
|
|
|
use inkwell::{
|
2024-06-12 14:45:03 +08:00
|
|
|
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
|
2022-02-12 21:17:37 +08:00
|
|
|
};
|
2021-10-31 17:16:21 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
use pyo3::{
|
|
|
|
types::{PyDict, PyList},
|
|
|
|
PyObject, PyResult, Python,
|
|
|
|
};
|
2022-03-25 22:42:01 +08:00
|
|
|
|
|
|
|
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
2021-10-31 17:16:21 +08:00
|
|
|
|
2024-06-17 15:01:22 +08:00
|
|
|
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
2022-02-12 21:17:37 +08:00
|
|
|
use std::{
|
|
|
|
collections::hash_map::DefaultHasher,
|
|
|
|
collections::HashMap,
|
|
|
|
hash::{Hash, Hasher},
|
|
|
|
sync::Arc,
|
|
|
|
};
|
|
|
|
|
2023-10-25 16:16:20 +08:00
|
|
|
/// The parallelism mode within a block.
///
/// Stored in `ArtiqCodeGenerator::parallel_mode` and consulted by `gen_block` (legacy mode) and
/// `gen_call` (deep mode) to decide when the timeline end is updated and the timeline is reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ParallelMode {
    /// No parallelism is currently registered for this context.
    None,

    /// Legacy (or shallow) parallelism. Default before NAC3.
    ///
    /// Each statement within the `with` block is treated as statements to be executed in parallel.
    Legacy,

    /// Deep parallelism. Default since NAC3.
    ///
    /// Each function call within the `with` block (except those within a nested `sequential` block)
    /// are treated to be executed in parallel.
    Deep,
}
|
|
|
|
|
2021-10-31 17:16:21 +08:00
|
|
|
/// ARTIQ-specific [CodeGenerator]: delegates most code generation to nac3core's defaults, and
/// additionally tracks timeline state for `with parallel`, `with legacy_parallel` and
/// `with sequential` blocks.
pub struct ArtiqCodeGenerator<'a> {
    /// Name of this generator, returned by `get_name`.
    name: String,

    /// The size of a `size_t` variable in bits.
    size_t: u32,

    /// Monotonic counter for naming `start`/`end` variables used by `with parallel` blocks.
    name_counter: u32,

    /// Variable for tracking the start of a `with parallel` block.
    start: Option<Expr<Option<Type>>>,

    /// Variable for tracking the end of a `with parallel` block.
    end: Option<Expr<Option<Type>>>,

    /// Timeline functions used to read (`emit_now_mu`) and set (`emit_at_mu`) the current
    /// timeline position.
    timeline: &'a (dyn TimeFns + Sync),

    /// The [ParallelMode] of the current parallel context.
    ///
    /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
    /// statement, which is used to determine when and how the timeline should be updated.
    parallel_mode: ParallelMode,
}
|
|
|
|
|
|
|
|
impl<'a> ArtiqCodeGenerator<'a> {
|
2022-02-12 21:17:37 +08:00
|
|
|
pub fn new(
|
|
|
|
name: String,
|
|
|
|
size_t: u32,
|
|
|
|
timeline: &'a (dyn TimeFns + Sync),
|
|
|
|
) -> ArtiqCodeGenerator<'a> {
|
2021-12-27 22:55:51 +08:00
|
|
|
assert!(size_t == 32 || size_t == 64);
|
2023-10-25 15:54:27 +08:00
|
|
|
ArtiqCodeGenerator {
|
|
|
|
name,
|
|
|
|
size_t,
|
|
|
|
name_counter: 0,
|
|
|
|
start: None,
|
|
|
|
end: None,
|
|
|
|
timeline,
|
2023-10-25 16:16:20 +08:00
|
|
|
parallel_mode: ParallelMode::None,
|
2023-10-25 15:54:27 +08:00
|
|
|
}
|
2021-10-31 17:16:21 +08:00
|
|
|
}
|
|
|
|
|
2023-10-24 19:08:23 +08:00
|
|
|
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
|
|
|
|
/// position of the timeline to the initial timeline position before entering the `parallel`
|
|
|
|
/// block.
|
|
|
|
///
|
|
|
|
/// Direct-`parallel` block context refers to when the generator is generating statements whose
|
|
|
|
/// closest parent `with` statement is a `with parallel` block.
|
2024-06-12 14:45:03 +08:00
|
|
|
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
|
2023-10-24 19:08:23 +08:00
|
|
|
if let Some(start) = self.start.clone() {
|
2024-06-12 14:45:03 +08:00
|
|
|
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
|
|
|
|
ctx,
|
|
|
|
self,
|
|
|
|
start.custom.unwrap(),
|
|
|
|
)?;
|
2023-10-24 19:08:23 +08:00
|
|
|
self.timeline.emit_at_mu(ctx, start_val);
|
2021-12-27 22:55:51 +08:00
|
|
|
}
|
2023-10-24 19:08:23 +08:00
|
|
|
|
|
|
|
Ok(())
|
2021-12-27 22:55:51 +08:00
|
|
|
}
|
|
|
|
|
2023-10-24 19:08:23 +08:00
|
|
|
/// If the generator is currently in a `parallel` block context, emits IR that updates the
|
|
|
|
/// maximum end position of the `parallel` block as specified by the timeline `end` value.
|
|
|
|
///
|
|
|
|
/// In general the `end` parameter should be set to `self.end` for updating the maximum end
|
|
|
|
/// position for the current `parallel` block. Other values can be passed in to update the
|
|
|
|
/// maximum end position for other `parallel` blocks.
|
|
|
|
///
|
|
|
|
/// `parallel`-block context refers to when the generator is generating statements within a
|
|
|
|
/// (possibly indirect) `parallel` block.
|
2023-10-25 15:54:27 +08:00
|
|
|
///
|
|
|
|
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
|
|
|
|
/// the end of the provided value name.
|
2023-12-06 11:49:02 +08:00
|
|
|
fn timeline_update_end_max(
|
2021-10-31 17:16:21 +08:00
|
|
|
&mut self,
|
2023-12-06 11:49:02 +08:00
|
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
2023-10-24 19:08:23 +08:00
|
|
|
end: Option<Expr<Option<Type>>>,
|
2023-10-25 15:54:27 +08:00
|
|
|
store_name: Option<&str>,
|
2023-10-24 19:08:23 +08:00
|
|
|
) -> Result<(), String> {
|
|
|
|
if let Some(end) = end {
|
2024-06-12 14:45:03 +08:00
|
|
|
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
|
2023-10-25 15:54:27 +08:00
|
|
|
ctx,
|
2024-06-12 14:45:03 +08:00
|
|
|
self,
|
|
|
|
end.custom.unwrap(),
|
|
|
|
)?;
|
|
|
|
let now = self.timeline.emit_now_mu(ctx);
|
|
|
|
let max =
|
|
|
|
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
|
|
|
|
let end_store = self
|
|
|
|
.gen_store_target(
|
|
|
|
ctx,
|
|
|
|
&end,
|
|
|
|
store_name.map(|name| format!("{name}.addr")).as_deref(),
|
|
|
|
)?
|
2023-12-06 15:26:37 +08:00
|
|
|
.unwrap();
|
2024-02-19 19:30:25 +08:00
|
|
|
ctx.builder.build_store(end_store, max).unwrap();
|
2021-10-31 17:16:21 +08:00
|
|
|
}
|
2023-10-24 19:08:23 +08:00
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
    /// Returns the name this generator was constructed with.
    fn get_name(&self) -> &str {
        &self.name
    }

    /// Returns the LLVM integer type for `size_t`: `i32` when `size_t` is 32, otherwise `i64`
    /// (`new` only accepts 32 or 64).
    fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
        if self.size_t == 32 {
            ctx.i32_type()
        } else {
            ctx.i64_type()
        }
    }

    /// Generates a sequence of statements.
    ///
    /// In [ParallelMode::Legacy], every top-level statement is followed by an update of the
    /// block's maximum end position and a reset of the timeline to the block start, so that each
    /// statement runs "in parallel". Generation stops early once the current basic block is
    /// terminated. In every other mode this delegates to nac3core's `gen_block`.
    fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
        &mut self,
        ctx: &mut CodeGenContext<'ctx, 'a>,
        stmts: I,
    ) -> Result<(), String>
    where
        Self: Sized,
    {
        // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
        // in the parallel block
        if self.parallel_mode == ParallelMode::Legacy {
            for stmt in stmts {
                self.gen_stmt(ctx, stmt)?;

                // Nothing may be emitted after a terminator (e.g. `return`).
                if ctx.is_terminated() {
                    break;
                }

                self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
                self.timeline_reset_start(ctx)?;
            }

            Ok(())
        } else {
            gen_block(self, ctx, stmts)
        }
    }

    /// Generates a function call via nac3core's `gen_call`.
    ///
    /// In [ParallelMode::Deep], every call is followed by an update of the block's maximum end
    /// position and a reset of the timeline to the block start.
    fn gen_call<'ctx>(
        &mut self,
        ctx: &mut CodeGenContext<'ctx, '_>,
        obj: Option<(Type, ValueEnum<'ctx>)>,
        fun: (&FunSignature, DefinitionId),
        params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
    ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
        let result = gen_call(self, ctx, obj, fun, params)?;

        // Deep parallel emits timeline end-update/timeline-reset after each function call
        if self.parallel_mode == ParallelMode::Deep {
            self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
            self.timeline_reset_start(ctx)?;
        }

        Ok(result)
    }

    /// Generates a `with` statement.
    ///
    /// A single-item `with` without an `as` target whose context expression is the bare name
    /// `parallel`, `legacy_parallel` or `sequential` is handled here. All other `with`
    /// statements are delegated to nac3core's `gen_with`.
    fn gen_with(
        &mut self,
        ctx: &mut CodeGenContext<'_, '_>,
        stmt: &Stmt<Option<Type>>,
    ) -> Result<(), String> {
        let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };

        if items.len() == 1 && items[0].optional_vars.is_none() {
            let item = &items[0];

            // Behavior of parallel and sequential:
            // Each function call (indirectly, can be inside a sequential block) within a parallel
            // block will update the end variable to the maximum now_mu in the block.
            // Each function call directly inside a parallel block will reset the timeline after
            // execution. A parallel block within a sequential block (or not within any block) will
            // set the timeline to the max now_mu within the block (and the outer max now_mu will also
            // be updated).
            //
            // Implementation: We track the start and end separately.
            // - If there is a start variable, it indicates that we are directly inside a
            // parallel block and we have to reset the timeline after every function call.
            // - If there is a end variable, it indicates that we are (indirectly) inside a
            // parallel block, and we should update the max end value.
            if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
                if id == &"parallel".into() || id == &"legacy_parallel".into() {
                    // Save the enclosing context; it is restored after the body is generated.
                    let old_start = self.start.take();
                    let old_end = self.end.take();
                    let old_parallel_mode = self.parallel_mode;

                    // Initial timeline position: the enclosing parallel block's start if there is
                    // one, otherwise the current timeline position.
                    let now = if let Some(old_start) = &old_start {
                        self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
                            ctx,
                            self,
                            old_start.custom.unwrap(),
                        )?
                    } else {
                        self.timeline.emit_now_mu(ctx)
                    };

                    // Emulate variable allocation, as we need to use the CodeGenContext
                    // HashMap to store our variable due to lifetime limitation
                    // Note: we should be able to store variables directly if generic
                    // associative type is used by limiting the lifetime of CodeGenerator to
                    // the LLVM Context.
                    // The name is guaranteed to be unique as users cannot use this as variable
                    // name.
                    // A nested parallel block reuses the enclosing block's start variable; only an
                    // outermost one allocates a fresh `with-N-start` variable.
                    self.start = old_start.clone().map_or_else(
                        || {
                            let start = format!("with-{}-start", self.name_counter).into();
                            let start_expr = Located {
                                // location does not matter at this point
                                location: stmt.location,
                                node: ExprKind::Name { id: start, ctx: *name_ctx },
                                custom: Some(ctx.primitives.int64),
                            };
                            let start = self
                                .gen_store_target(ctx, &start_expr, Some("start.addr"))?
                                .unwrap();
                            ctx.builder.build_store(start, now).unwrap();
                            Ok(Some(start_expr)) as Result<_, String>
                        },
                        |v| Ok(Some(v)),
                    )?;
                    // Every parallel block gets its own `with-N-end` variable, initialized to
                    // the start position.
                    let end = format!("with-{}-end", self.name_counter).into();
                    let end_expr = Located {
                        // location does not matter at this point
                        location: stmt.location,
                        node: ExprKind::Name { id: end, ctx: *name_ctx },
                        custom: Some(ctx.primitives.int64),
                    };
                    let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
                    ctx.builder.build_store(end, now).unwrap();
                    self.end = Some(end_expr);
                    self.name_counter += 1;
                    self.parallel_mode = match id.to_string().as_str() {
                        "parallel" => ParallelMode::Deep,
                        "legacy_parallel" => ParallelMode::Legacy,
                        _ => unreachable!(),
                    };

                    self.gen_block(ctx, body.iter())?;

                    let current = ctx.builder.get_insert_block().unwrap();

                    // if the current block is terminated, move before the terminator
                    // we want to set the timeline before reaching the terminator
                    // TODO: This may be unsound if there are multiple exit paths in the
                    // block... e.g.
                    // if ...:
                    //     return
                    // Perhaps we can fix this by using actual with block?
                    let reset_position = if let Some(terminator) = current.get_terminator() {
                        ctx.builder.position_before(&terminator);
                        true
                    } else {
                        false
                    };

                    // set duration
                    let end_expr = self.end.take().unwrap();
                    let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
                        ctx,
                        self,
                        end_expr.custom.unwrap(),
                    )?;

                    // inside a sequential block
                    if old_start.is_none() {
                        self.timeline.emit_at_mu(ctx, end_val);
                    }

                    // inside a parallel block, should update the outer max now_mu
                    self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;

                    // Restore the enclosing context saved above.
                    self.parallel_mode = old_parallel_mode;
                    self.end = old_end;
                    self.start = old_start;

                    // Undo the move-before-terminator so later code appends at the block end.
                    if reset_position {
                        ctx.builder.position_at_end(current);
                    }

                    return Ok(());
                } else if id == &"sequential".into() {
                    // For deep parallel, temporarily take away start to avoid function calls in
                    // the block from resetting the timeline.
                    // This does not affect legacy parallel, as the timeline will be reset after
                    // this block finishes execution.
                    let start = self.start.take();
                    self.gen_block(ctx, body.iter())?;
                    self.start = start;

                    // Reset the timeline when we are exiting the sequential block
                    // Legacy parallel does not need this, since it will be reset after codegen
                    // for this statement is completed
                    if self.parallel_mode == ParallelMode::Deep {
                        self.timeline_reset_start(ctx)?;
                    }

                    return Ok(());
                }
            }
        }

        // not parallel/sequential
        gen_with(self, ctx, stmt)
    }
}
|
2022-02-12 21:17:37 +08:00
|
|
|
|
2023-12-06 11:49:02 +08:00
|
|
|
fn gen_rpc_tag(
|
|
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
2022-02-21 18:27:46 +08:00
|
|
|
ty: Type,
|
|
|
|
buffer: &mut Vec<u8>,
|
|
|
|
) -> Result<(), String> {
|
2022-02-12 21:17:37 +08:00
|
|
|
use nac3core::typecheck::typedef::TypeEnum::*;
|
|
|
|
|
|
|
|
let int32 = ctx.primitives.int32;
|
|
|
|
let int64 = ctx.primitives.int64;
|
|
|
|
let float = ctx.primitives.float;
|
|
|
|
let bool = ctx.primitives.bool;
|
|
|
|
let str = ctx.primitives.str;
|
|
|
|
let none = ctx.primitives.none;
|
|
|
|
|
|
|
|
if ctx.unifier.unioned(ty, int32) {
|
|
|
|
buffer.push(b'i');
|
|
|
|
} else if ctx.unifier.unioned(ty, int64) {
|
|
|
|
buffer.push(b'I');
|
|
|
|
} else if ctx.unifier.unioned(ty, float) {
|
|
|
|
buffer.push(b'f');
|
|
|
|
} else if ctx.unifier.unioned(ty, bool) {
|
|
|
|
buffer.push(b'b');
|
|
|
|
} else if ctx.unifier.unioned(ty, str) {
|
|
|
|
buffer.push(b's');
|
|
|
|
} else if ctx.unifier.unioned(ty, none) {
|
|
|
|
buffer.push(b'n');
|
|
|
|
} else {
|
2022-02-21 17:52:34 +08:00
|
|
|
let ty_enum = ctx.unifier.get_ty(ty);
|
|
|
|
match &*ty_enum {
|
2022-02-12 21:17:37 +08:00
|
|
|
TTuple { ty } => {
|
|
|
|
buffer.push(b't');
|
|
|
|
buffer.push(ty.len() as u8);
|
|
|
|
for ty in ty {
|
2022-02-21 17:52:34 +08:00
|
|
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
TList { ty } => {
|
|
|
|
buffer.push(b'l');
|
2022-02-21 17:52:34 +08:00
|
|
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
2024-06-17 15:01:22 +08:00
|
|
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
|
|
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
|
|
let ndarray_ndims = if let TLiteral { values, .. } =
|
|
|
|
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
|
|
|
{
|
|
|
|
if values.len() != 1 {
|
|
|
|
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
|
|
|
}
|
|
|
|
|
|
|
|
let value = values[0].clone();
|
|
|
|
u64::try_from(value.clone())
|
|
|
|
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
|
|
|
} else {
|
|
|
|
unreachable!()
|
|
|
|
};
|
|
|
|
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
|
|
|
|
|
|
|
buffer.push(b'a');
|
|
|
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
|
|
|
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
|
|
|
}
|
2022-02-21 17:52:34 +08:00
|
|
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
}
|
2022-02-21 17:52:34 +08:00
|
|
|
Ok(())
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-06 11:49:02 +08:00
|
|
|
fn rpc_codegen_callback_fn<'ctx>(
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
2022-02-12 21:17:37 +08:00
|
|
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
|
|
|
generator: &mut dyn CodeGenerator,
|
2022-02-21 17:52:34 +08:00
|
|
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
2023-10-26 13:52:40 +08:00
|
|
|
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
2022-02-12 21:17:37 +08:00
|
|
|
let size_type = generator.get_size_type(ctx.ctx);
|
|
|
|
let int8 = ctx.ctx.i8_type();
|
|
|
|
let int32 = ctx.ctx.i32_type();
|
|
|
|
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
2022-02-12 21:17:37 +08:00
|
|
|
// -- setup rpc tags
|
|
|
|
let mut tag = Vec::new();
|
|
|
|
if obj.is_some() {
|
|
|
|
tag.push(b'O');
|
|
|
|
}
|
2023-12-11 15:04:35 +08:00
|
|
|
for arg in &fun.0.args {
|
2022-02-21 17:52:34 +08:00
|
|
|
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
tag.push(b':');
|
2022-02-21 17:52:34 +08:00
|
|
|
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
2022-02-12 21:17:37 +08:00
|
|
|
|
|
|
|
let mut hasher = DefaultHasher::new();
|
|
|
|
tag.hash(&mut hasher);
|
|
|
|
let hash = format!("{}", hasher.finish());
|
|
|
|
|
|
|
|
let tag_ptr = ctx
|
|
|
|
.module
|
|
|
|
.get_global(hash.as_str())
|
|
|
|
.unwrap_or_else(|| {
|
|
|
|
let tag_arr_ptr = ctx.module.add_global(
|
|
|
|
int8.array_type(tag.len() as u32),
|
|
|
|
None,
|
|
|
|
format!("tagptr{}", fun.1 .0).as_str(),
|
|
|
|
);
|
|
|
|
tag_arr_ptr.set_initializer(&int8.const_array(
|
2024-06-12 15:13:09 +08:00
|
|
|
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
2022-02-12 21:17:37 +08:00
|
|
|
));
|
|
|
|
tag_arr_ptr.set_linkage(Linkage::Private);
|
|
|
|
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
|
|
|
tag_ptr.set_linkage(Linkage::Private);
|
|
|
|
tag_ptr.set_initializer(&ctx.ctx.const_struct(
|
|
|
|
&[
|
|
|
|
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
|
|
|
|
size_type.const_int(tag.len() as u64, false).into(),
|
|
|
|
],
|
|
|
|
false,
|
|
|
|
));
|
|
|
|
tag_ptr
|
|
|
|
})
|
|
|
|
.as_pointer_value();
|
|
|
|
|
2023-12-11 15:04:35 +08:00
|
|
|
let arg_length = args.len() + usize::from(obj.is_some());
|
2022-02-12 21:17:37 +08:00
|
|
|
|
2024-02-22 01:47:26 +08:00
|
|
|
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
2024-06-12 14:45:03 +08:00
|
|
|
let args_ptr = ctx
|
|
|
|
.builder
|
2024-02-19 19:30:25 +08:00
|
|
|
.build_array_alloca(
|
|
|
|
ptr_type,
|
|
|
|
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
|
|
|
"argptr",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
|
|
|
|
// -- rpc args handling
|
|
|
|
let mut keys = fun.0.args.clone();
|
|
|
|
let mut mapping = HashMap::new();
|
2023-12-11 15:04:35 +08:00
|
|
|
for (key, value) in args {
|
2022-02-12 21:17:37 +08:00
|
|
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
|
|
|
}
|
|
|
|
// default value handling
|
2023-12-11 15:04:35 +08:00
|
|
|
for k in keys {
|
2024-06-12 14:45:03 +08:00
|
|
|
mapping
|
|
|
|
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
// reorder the parameters
|
|
|
|
let mut real_params = fun
|
|
|
|
.0
|
|
|
|
.args
|
|
|
|
.iter()
|
2022-04-08 03:26:42 +08:00
|
|
|
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
|
2022-02-28 23:09:14 +08:00
|
|
|
.collect::<Result<Vec<_>, _>>()?;
|
2022-02-12 21:17:37 +08:00
|
|
|
if let Some(obj) = obj {
|
|
|
|
if let ValueEnum::Static(obj) = obj.1 {
|
|
|
|
real_params.insert(0, obj.get_const_obj(ctx, generator));
|
|
|
|
} else {
|
|
|
|
// should be an error here...
|
|
|
|
panic!("only host object is allowed");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for (i, arg) in real_params.iter().enumerate() {
|
2024-06-12 14:45:03 +08:00
|
|
|
let arg_slot =
|
|
|
|
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
|
2024-02-19 19:30:25 +08:00
|
|
|
ctx.builder.build_store(arg_slot, *arg).unwrap();
|
|
|
|
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
let arg_ptr = unsafe {
|
|
|
|
ctx.builder.build_gep(
|
|
|
|
args_ptr,
|
|
|
|
&[int32.const_int(i as u64, false)],
|
2023-12-11 15:04:35 +08:00
|
|
|
&format!("rpc.arg{i}"),
|
2022-02-12 21:17:37 +08:00
|
|
|
)
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
|
|
|
.unwrap();
|
2024-02-19 19:30:25 +08:00
|
|
|
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// call
|
|
|
|
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
|
|
|
ctx.module.add_function(
|
|
|
|
"rpc_send",
|
|
|
|
ctx.ctx.void_type().fn_type(
|
|
|
|
&[
|
|
|
|
int32.into(),
|
2023-01-12 19:31:03 +08:00
|
|
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
|
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
2022-02-12 21:17:37 +08:00
|
|
|
],
|
|
|
|
false,
|
|
|
|
),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
});
|
2024-02-19 19:30:25 +08:00
|
|
|
ctx.builder
|
2024-06-12 14:45:03 +08:00
|
|
|
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
2024-02-19 19:30:25 +08:00
|
|
|
.unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
|
|
|
|
// reclaim stack space used by arguments
|
2024-02-22 01:47:26 +08:00
|
|
|
call_stackrestore(ctx, stackptr);
|
2022-02-12 21:17:37 +08:00
|
|
|
|
|
|
|
// -- receive value:
|
|
|
|
// T result = {
|
|
|
|
// void *ret_ptr = alloca(sizeof(T));
|
|
|
|
// void *ptr = ret_ptr;
|
|
|
|
// loop: int size = rpc_recv(ptr);
|
|
|
|
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
|
|
|
// if(size) { ptr = alloca(size); goto loop; }
|
|
|
|
// else *(T*)ret_ptr
|
|
|
|
// }
|
|
|
|
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
|
|
|
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
|
|
|
|
});
|
|
|
|
|
|
|
|
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
|
|
|
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
|
2022-02-21 18:27:46 +08:00
|
|
|
return Ok(None);
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
|
|
|
let current_function = prehead_bb.get_parent().unwrap();
|
|
|
|
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
|
|
|
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
|
|
|
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
|
|
|
|
2023-09-20 13:34:50 +08:00
|
|
|
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
|
2022-02-12 21:17:37 +08:00
|
|
|
let need_load = !ret_ty.is_pointer_type();
|
2024-02-19 19:30:25 +08:00
|
|
|
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
|
|
|
|
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
|
|
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
ctx.builder.position_at_end(head_bb);
|
|
|
|
|
2024-02-19 19:30:25 +08:00
|
|
|
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
|
|
|
let alloc_size = ctx
|
|
|
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
2024-06-12 14:45:03 +08:00
|
|
|
let is_done = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
|
2024-02-19 19:30:25 +08:00
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
ctx.builder.position_at_end(alloc_bb);
|
|
|
|
|
2024-02-19 19:30:25 +08:00
|
|
|
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
|
|
|
|
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
2024-02-19 19:30:25 +08:00
|
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
2022-02-12 21:17:37 +08:00
|
|
|
|
|
|
|
ctx.builder.position_at_end(tail_bb);
|
|
|
|
|
2024-02-19 19:30:25 +08:00
|
|
|
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
|
2022-03-10 16:48:28 +08:00
|
|
|
if need_load {
|
2024-02-22 01:47:26 +08:00
|
|
|
call_stackrestore(ctx, stackptr);
|
2022-03-10 16:48:28 +08:00
|
|
|
}
|
|
|
|
Ok(Some(result))
|
2022-02-12 21:17:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-06 11:49:02 +08:00
|
|
|
pub fn attributes_writeback(
|
|
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
2022-03-25 22:42:01 +08:00
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
inner_resolver: &InnerResolver,
|
2023-12-11 15:04:35 +08:00
|
|
|
host_attributes: &PyObject,
|
2022-03-25 22:42:01 +08:00
|
|
|
) -> Result<(), String> {
|
|
|
|
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
2023-09-01 16:56:32 +08:00
|
|
|
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
2022-03-25 22:42:01 +08:00
|
|
|
let top_levels = ctx.top_level.definitions.read();
|
|
|
|
let globals = inner_resolver.global_value_ids.read();
|
|
|
|
let int32 = ctx.ctx.i32_type();
|
|
|
|
let zero = int32.const_zero();
|
|
|
|
let mut values = Vec::new();
|
|
|
|
let mut scratch_buffer = Vec::new();
|
2023-12-11 15:04:35 +08:00
|
|
|
for val in (*globals).values() {
|
2022-03-25 22:42:01 +08:00
|
|
|
let val = val.as_ref(py);
|
2024-06-12 14:45:03 +08:00
|
|
|
let ty = inner_resolver.get_obj_type(
|
|
|
|
py,
|
|
|
|
val,
|
|
|
|
&mut ctx.unifier,
|
|
|
|
&top_levels,
|
|
|
|
&ctx.primitives,
|
|
|
|
)?;
|
2022-03-25 22:42:01 +08:00
|
|
|
if let Err(ty) = ty {
|
2024-06-12 14:45:03 +08:00
|
|
|
return Ok(Err(ty));
|
2022-03-25 22:42:01 +08:00
|
|
|
}
|
|
|
|
let ty = ty.unwrap();
|
|
|
|
match &*ctx.unifier.get_ty(ty) {
|
2022-04-09 03:50:39 +08:00
|
|
|
TypeEnum::TObj { fields, obj_id, .. }
|
2024-03-27 10:36:02 +08:00
|
|
|
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
2022-04-09 03:50:39 +08:00
|
|
|
{
|
2022-03-25 22:42:01 +08:00
|
|
|
// we only care about primitive attributes
|
|
|
|
// for non-primitive attributes, they should be in another global
|
|
|
|
let mut attributes = Vec::new();
|
2022-04-10 01:02:52 +08:00
|
|
|
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
|
2023-12-11 15:04:35 +08:00
|
|
|
for (name, (field_ty, is_mutable)) in fields {
|
2022-03-25 22:42:01 +08:00
|
|
|
if !is_mutable {
|
2024-06-12 14:45:03 +08:00
|
|
|
continue;
|
2022-03-25 22:42:01 +08:00
|
|
|
}
|
|
|
|
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
|
|
|
attributes.push(name.to_string());
|
|
|
|
let index = ctx.get_attr_index(ty, *name);
|
2024-06-12 14:45:03 +08:00
|
|
|
values.push((
|
|
|
|
*field_ty,
|
|
|
|
ctx.build_gep_and_load(
|
|
|
|
obj.into_pointer_value(),
|
|
|
|
&[zero, int32.const_int(index as u64, false)],
|
|
|
|
None,
|
|
|
|
),
|
|
|
|
));
|
2022-03-25 22:42:01 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
if !attributes.is_empty() {
|
|
|
|
let pydict = PyDict::new(py);
|
|
|
|
pydict.set_item("obj", val)?;
|
|
|
|
pydict.set_item("fields", attributes)?;
|
|
|
|
host_attributes.append(pydict)?;
|
|
|
|
}
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
2022-03-25 22:42:01 +08:00
|
|
|
TypeEnum::TList { ty: elem_ty } => {
|
|
|
|
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
|
|
|
let pydict = PyDict::new(py);
|
|
|
|
pydict.set_item("obj", val)?;
|
|
|
|
host_attributes.append(pydict)?;
|
2024-06-12 14:45:03 +08:00
|
|
|
values.push((
|
|
|
|
ty,
|
|
|
|
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
|
|
|
|
));
|
2022-03-25 22:42:01 +08:00
|
|
|
}
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
2022-03-25 22:42:01 +08:00
|
|
|
_ => {}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
let fun = FunSignature {
|
2024-06-12 14:45:03 +08:00
|
|
|
args: values
|
|
|
|
.iter()
|
|
|
|
.enumerate()
|
|
|
|
.map(|(i, (ty, _))| FuncArg {
|
|
|
|
name: i.to_string().into(),
|
|
|
|
ty: *ty,
|
|
|
|
default_value: None,
|
|
|
|
})
|
|
|
|
.collect(),
|
2022-03-25 22:42:01 +08:00
|
|
|
ret: ctx.primitives.none,
|
2024-06-12 14:45:03 +08:00
|
|
|
vars: VarMap::default(),
|
2022-03-25 22:42:01 +08:00
|
|
|
};
|
2024-06-12 14:45:03 +08:00
|
|
|
let args: Vec<_> =
|
|
|
|
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
|
|
|
if let Err(e) =
|
2024-06-12 15:01:01 +08:00
|
|
|
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
|
2024-06-12 14:45:03 +08:00
|
|
|
{
|
2022-03-25 22:42:01 +08:00
|
|
|
return Ok(Err(e));
|
|
|
|
}
|
|
|
|
Ok(Ok(()))
|
2024-06-12 14:45:03 +08:00
|
|
|
})
|
|
|
|
.unwrap()?;
|
2022-03-25 22:42:01 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
2022-02-12 21:17:37 +08:00
|
|
|
pub fn rpc_codegen_callback() -> Arc<GenCall> {
|
|
|
|
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
|
|
|
|
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
|
|
|
|
})))
|
|
|
|
}
|