forked from M-Labs/nac3
David Mak
06092ad29b
These functions may not be invocable with the same set of parameters, as some classes have associated state.
1591 lines
60 KiB
Rust
1591 lines
60 KiB
Rust
use std::{
|
|
collections::{hash_map::DefaultHasher, HashMap},
|
|
hash::{Hash, Hasher},
|
|
iter::once,
|
|
mem,
|
|
sync::Arc,
|
|
};
|
|
|
|
use itertools::Itertools;
|
|
use pyo3::{
|
|
types::{PyDict, PyList},
|
|
PyObject, PyResult, Python,
|
|
};
|
|
|
|
use nac3core::{
|
|
codegen::{
|
|
expr::{destructure_range, gen_call},
|
|
irrt::call_ndarray_calc_size,
|
|
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
|
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
|
types::NDArrayType,
|
|
values::{
|
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
|
|
RangeValue, UntypedArrayLikeAccessor,
|
|
},
|
|
CodeGenContext, CodeGenerator,
|
|
},
|
|
inkwell::{
|
|
context::Context,
|
|
module::Linkage,
|
|
types::{BasicType, IntType},
|
|
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
|
|
AddressSpace, IntPredicate, OptimizationLevel,
|
|
},
|
|
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
|
|
symbol_resolver::ValueEnum,
|
|
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
|
|
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
|
};
|
|
|
|
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
|
|
|
/// How statements inside a parallel `with` block are scheduled on the timeline.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ParallelMode {
    /// No parallel block is registered for the current context.
    None,

    /// Legacy (shallow) parallelism, the behavior before NAC3.
    ///
    /// Every top-level statement inside the `with` block is treated as a parallel branch: after
    /// each statement the block's maximum end is updated and the timeline is reset to the
    /// block's start.
    Legacy,

    /// Deep parallelism, the default since NAC3.
    ///
    /// Every function call inside the `with` block (except those within a nested `sequential`
    /// block) is treated as a parallel branch.
    Deep,
}
|
|
|
|
pub struct ArtiqCodeGenerator<'a> {
|
|
name: String,
|
|
|
|
/// The size of a `size_t` variable in bits.
|
|
size_t: u32,
|
|
|
|
/// Monotonic counter for naming `start`/`stop` variables used by `with parallel` blocks.
|
|
name_counter: u32,
|
|
|
|
/// Variable for tracking the start of a `with parallel` block.
|
|
start: Option<Expr<Option<Type>>>,
|
|
|
|
/// Variable for tracking the end of a `with parallel` block.
|
|
end: Option<Expr<Option<Type>>>,
|
|
timeline: &'a (dyn TimeFns + Sync),
|
|
|
|
/// The [`ParallelMode`] of the current parallel context.
|
|
///
|
|
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
|
|
/// statement, which is used to determine when and how the timeline should be updated.
|
|
parallel_mode: ParallelMode,
|
|
}
|
|
|
|
impl<'a> ArtiqCodeGenerator<'a> {
|
|
pub fn new(
|
|
name: String,
|
|
size_t: u32,
|
|
timeline: &'a (dyn TimeFns + Sync),
|
|
) -> ArtiqCodeGenerator<'a> {
|
|
assert!(size_t == 32 || size_t == 64);
|
|
ArtiqCodeGenerator {
|
|
name,
|
|
size_t,
|
|
name_counter: 0,
|
|
start: None,
|
|
end: None,
|
|
timeline,
|
|
parallel_mode: ParallelMode::None,
|
|
}
|
|
}
|
|
|
|
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
|
|
/// position of the timeline to the initial timeline position before entering the `parallel`
|
|
/// block.
|
|
///
|
|
/// Direct-`parallel` block context refers to when the generator is generating statements whose
|
|
/// closest parent `with` statement is a `with parallel` block.
|
|
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
|
|
if let Some(start) = self.start.clone() {
|
|
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
start.custom.unwrap(),
|
|
)?;
|
|
self.timeline.emit_at_mu(ctx, start_val);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// If the generator is currently in a `parallel` block context, emits IR that updates the
|
|
/// maximum end position of the `parallel` block as specified by the timeline `end` value.
|
|
///
|
|
/// In general the `end` parameter should be set to `self.end` for updating the maximum end
|
|
/// position for the current `parallel` block. Other values can be passed in to update the
|
|
/// maximum end position for other `parallel` blocks.
|
|
///
|
|
/// `parallel`-block context refers to when the generator is generating statements within a
|
|
/// (possibly indirect) `parallel` block.
|
|
///
|
|
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
|
|
/// the end of the provided value name.
|
|
fn timeline_update_end_max(
|
|
&mut self,
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
|
end: Option<Expr<Option<Type>>>,
|
|
store_name: Option<&str>,
|
|
) -> Result<(), String> {
|
|
if let Some(end) = end {
|
|
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
end.custom.unwrap(),
|
|
)?;
|
|
let now = self.timeline.emit_now_mu(ctx);
|
|
let max =
|
|
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
|
|
let end_store = self
|
|
.gen_store_target(
|
|
ctx,
|
|
&end,
|
|
store_name.map(|name| format!("{name}.addr")).as_deref(),
|
|
)?
|
|
.unwrap();
|
|
ctx.builder.build_store(end_store, max).unwrap();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
|
fn get_name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
|
|
if self.size_t == 32 {
|
|
ctx.i32_type()
|
|
} else {
|
|
ctx.i64_type()
|
|
}
|
|
}
|
|
|
|
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
|
|
&mut self,
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
stmts: I,
|
|
) -> Result<(), String>
|
|
where
|
|
Self: Sized,
|
|
{
|
|
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
|
|
// in the parallel block
|
|
if self.parallel_mode == ParallelMode::Legacy {
|
|
for stmt in stmts {
|
|
self.gen_stmt(ctx, stmt)?;
|
|
|
|
if ctx.is_terminated() {
|
|
break;
|
|
}
|
|
|
|
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
|
|
self.timeline_reset_start(ctx)?;
|
|
}
|
|
|
|
Ok(())
|
|
} else {
|
|
gen_block(self, ctx, stmts)
|
|
}
|
|
}
|
|
|
|
fn gen_call<'ctx>(
|
|
&mut self,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
|
let result = gen_call(self, ctx, obj, fun, params)?;
|
|
|
|
// Deep parallel emits timeline end-update/timeline-reset after each function call
|
|
if self.parallel_mode == ParallelMode::Deep {
|
|
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
|
|
self.timeline_reset_start(ctx)?;
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
fn gen_with(
|
|
&mut self,
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
|
stmt: &Stmt<Option<Type>>,
|
|
) -> Result<(), String> {
|
|
let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
|
|
|
|
if items.len() == 1 && items[0].optional_vars.is_none() {
|
|
let item = &items[0];
|
|
|
|
// Behavior of parallel and sequential:
|
|
// Each function call (indirectly, can be inside a sequential block) within a parallel
|
|
// block will update the end variable to the maximum now_mu in the block.
|
|
// Each function call directly inside a parallel block will reset the timeline after
|
|
// execution. A parallel block within a sequential block (or not within any block) will
|
|
// set the timeline to the max now_mu within the block (and the outer max now_mu will also
|
|
// be updated).
|
|
//
|
|
// Implementation: We track the start and end separately.
|
|
// - If there is a start variable, it indicates that we are directly inside a
|
|
// parallel block and we have to reset the timeline after every function call.
|
|
// - If there is a end variable, it indicates that we are (indirectly) inside a
|
|
// parallel block, and we should update the max end value.
|
|
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
|
|
if id == &"parallel".into() || id == &"legacy_parallel".into() {
|
|
let old_start = self.start.take();
|
|
let old_end = self.end.take();
|
|
let old_parallel_mode = self.parallel_mode;
|
|
|
|
let now = if let Some(old_start) = &old_start {
|
|
self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
old_start.custom.unwrap(),
|
|
)?
|
|
} else {
|
|
self.timeline.emit_now_mu(ctx)
|
|
};
|
|
|
|
// Emulate variable allocation, as we need to use the CodeGenContext
|
|
// HashMap to store our variable due to lifetime limitation
|
|
// Note: we should be able to store variables directly if generic
|
|
// associative type is used by limiting the lifetime of CodeGenerator to
|
|
// the LLVM Context.
|
|
// The name is guaranteed to be unique as users cannot use this as variable
|
|
// name.
|
|
self.start = old_start.clone().map_or_else(
|
|
|| {
|
|
let start = format!("with-{}-start", self.name_counter).into();
|
|
let start_expr = Located {
|
|
// location does not matter at this point
|
|
location: stmt.location,
|
|
node: ExprKind::Name { id: start, ctx: *name_ctx },
|
|
custom: Some(ctx.primitives.int64),
|
|
};
|
|
let start = self
|
|
.gen_store_target(ctx, &start_expr, Some("start.addr"))?
|
|
.unwrap();
|
|
ctx.builder.build_store(start, now).unwrap();
|
|
Ok(Some(start_expr)) as Result<_, String>
|
|
},
|
|
|v| Ok(Some(v)),
|
|
)?;
|
|
let end = format!("with-{}-end", self.name_counter).into();
|
|
let end_expr = Located {
|
|
// location does not matter at this point
|
|
location: stmt.location,
|
|
node: ExprKind::Name { id: end, ctx: *name_ctx },
|
|
custom: Some(ctx.primitives.int64),
|
|
};
|
|
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
|
|
ctx.builder.build_store(end, now).unwrap();
|
|
self.end = Some(end_expr);
|
|
self.name_counter += 1;
|
|
self.parallel_mode = match id.to_string().as_str() {
|
|
"parallel" => ParallelMode::Deep,
|
|
"legacy_parallel" => ParallelMode::Legacy,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
self.gen_block(ctx, body.iter())?;
|
|
|
|
let current = ctx.builder.get_insert_block().unwrap();
|
|
|
|
// if the current block is terminated, move before the terminator
|
|
// we want to set the timeline before reaching the terminator
|
|
// TODO: This may be unsound if there are multiple exit paths in the
|
|
// block... e.g.
|
|
// if ...:
|
|
// return
|
|
// Perhaps we can fix this by using actual with block?
|
|
let reset_position = if let Some(terminator) = current.get_terminator() {
|
|
ctx.builder.position_before(&terminator);
|
|
true
|
|
} else {
|
|
false
|
|
};
|
|
|
|
// set duration
|
|
let end_expr = self.end.take().unwrap();
|
|
let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
end_expr.custom.unwrap(),
|
|
)?;
|
|
|
|
// inside a sequential block
|
|
if old_start.is_none() {
|
|
self.timeline.emit_at_mu(ctx, end_val);
|
|
}
|
|
|
|
// inside a parallel block, should update the outer max now_mu
|
|
self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;
|
|
|
|
self.parallel_mode = old_parallel_mode;
|
|
self.end = old_end;
|
|
self.start = old_start;
|
|
|
|
if reset_position {
|
|
ctx.builder.position_at_end(current);
|
|
}
|
|
|
|
return Ok(());
|
|
} else if id == &"sequential".into() {
|
|
// For deep parallel, temporarily take away start to avoid function calls in
|
|
// the block from resetting the timeline.
|
|
// This does not affect legacy parallel, as the timeline will be reset after
|
|
// this block finishes execution.
|
|
let start = self.start.take();
|
|
self.gen_block(ctx, body.iter())?;
|
|
self.start = start;
|
|
|
|
// Reset the timeline when we are exiting the sequential block
|
|
// Legacy parallel does not need this, since it will be reset after codegen
|
|
// for this statement is completed
|
|
if self.parallel_mode == ParallelMode::Deep {
|
|
self.timeline_reset_start(ctx)?;
|
|
}
|
|
|
|
return Ok(());
|
|
}
|
|
}
|
|
}
|
|
|
|
// not parallel/sequential
|
|
gen_with(self, ctx, stmt)
|
|
}
|
|
}
|
|
|
|
fn gen_rpc_tag(
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
|
ty: Type,
|
|
buffer: &mut Vec<u8>,
|
|
) -> Result<(), String> {
|
|
use nac3core::typecheck::typedef::TypeEnum::*;
|
|
|
|
let int32 = ctx.primitives.int32;
|
|
let int64 = ctx.primitives.int64;
|
|
let float = ctx.primitives.float;
|
|
let bool = ctx.primitives.bool;
|
|
let str = ctx.primitives.str;
|
|
let none = ctx.primitives.none;
|
|
|
|
if ctx.unifier.unioned(ty, int32) {
|
|
buffer.push(b'i');
|
|
} else if ctx.unifier.unioned(ty, int64) {
|
|
buffer.push(b'I');
|
|
} else if ctx.unifier.unioned(ty, float) {
|
|
buffer.push(b'f');
|
|
} else if ctx.unifier.unioned(ty, bool) {
|
|
buffer.push(b'b');
|
|
} else if ctx.unifier.unioned(ty, str) {
|
|
buffer.push(b's');
|
|
} else if ctx.unifier.unioned(ty, none) {
|
|
buffer.push(b'n');
|
|
} else {
|
|
let ty_enum = ctx.unifier.get_ty(ty);
|
|
match &*ty_enum {
|
|
TTuple { ty, is_vararg_ctx: false } => {
|
|
buffer.push(b't');
|
|
buffer.push(ty.len() as u8);
|
|
for ty in ty {
|
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
|
}
|
|
}
|
|
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let ty = iter_type_vars(params).next().unwrap().ty;
|
|
|
|
buffer.push(b'l');
|
|
gen_rpc_tag(ctx, ty, buffer)?;
|
|
}
|
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
let ndarray_ndims = if let TLiteral { values, .. } =
|
|
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
|
{
|
|
if values.len() != 1 {
|
|
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
|
}
|
|
|
|
let value = values[0].clone();
|
|
u64::try_from(value.clone())
|
|
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
|
} else {
|
|
unreachable!()
|
|
};
|
|
assert!(
|
|
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
|
|
"Only NDArrays of sizes between 0 and 255 can be RPCed"
|
|
);
|
|
|
|
buffer.push(b'a');
|
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
|
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
|
}
|
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Formats an RPC argument to conform to the expected format required by `send_value`.
|
|
///
|
|
/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format.
|
|
fn format_rpc_arg<'ctx>(
|
|
generator: &mut dyn CodeGenerator,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
(arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
|
|
) -> PointerValue<'ctx> {
|
|
let llvm_i8 = ctx.ctx.i8_type();
|
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
|
|
let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
// NAC3: NDArray = { usize, usize*, T* }
|
|
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
let llvm_arg = NDArrayValue::from_pointer_value(
|
|
arg.into_pointer_value(),
|
|
llvm_elem_ty,
|
|
llvm_usize,
|
|
None,
|
|
);
|
|
|
|
let llvm_usize_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(
|
|
llvm_arg.get_type().size_type().size_of(),
|
|
llvm_usize,
|
|
"",
|
|
)
|
|
.unwrap();
|
|
let llvm_pdata_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(
|
|
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
|
llvm_usize,
|
|
"",
|
|
)
|
|
.unwrap();
|
|
|
|
let dims_buf_sz =
|
|
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
|
|
|
let buffer_size =
|
|
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
|
|
|
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
|
|
|
call_memcpy_generic(
|
|
ctx,
|
|
buffer.base_ptr(ctx, generator),
|
|
llvm_arg.ptr_to_data(ctx),
|
|
llvm_pdata_sizeof,
|
|
llvm_i1.const_zero(),
|
|
);
|
|
|
|
let pbuffer_dims_begin =
|
|
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
|
call_memcpy_generic(
|
|
ctx,
|
|
pbuffer_dims_begin,
|
|
llvm_arg.shape().base_ptr(ctx, generator),
|
|
dims_buf_sz,
|
|
llvm_i1.const_zero(),
|
|
);
|
|
|
|
buffer.base_ptr(ctx, generator)
|
|
}
|
|
|
|
_ => {
|
|
let arg_slot = generator
|
|
.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))
|
|
.unwrap();
|
|
ctx.builder.build_store(arg_slot, arg).unwrap();
|
|
|
|
ctx.builder
|
|
.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
|
|
.map(BasicValueEnum::into_pointer_value)
|
|
.unwrap()
|
|
}
|
|
};
|
|
|
|
debug_assert_eq!(arg_slot.get_type(), llvm_pi8);
|
|
|
|
arg_slot
|
|
}
|
|
|
|
/// Formats an RPC return value to conform to the expected format required by NAC3.
|
|
fn format_rpc_ret<'ctx>(
|
|
generator: &mut dyn CodeGenerator,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
ret_ty: Type,
|
|
) -> Option<BasicValueEnum<'ctx>> {
|
|
// -- receive value:
|
|
// T result = {
|
|
// void *ret_ptr = alloca(sizeof(T));
|
|
// void *ptr = ret_ptr;
|
|
// loop: int size = rpc_recv(ptr);
|
|
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
|
// if(size) { ptr = alloca(size); goto loop; }
|
|
// else *(T*)ret_ptr
|
|
// }
|
|
|
|
let llvm_i8 = ctx.ctx.i8_type();
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
|
|
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
|
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
|
});
|
|
|
|
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
|
|
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
|
|
return None;
|
|
}
|
|
|
|
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
|
let current_function = prehead_bb.get_parent().unwrap();
|
|
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
|
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
|
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
|
|
|
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
|
|
|
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
// Round `val` up to its modulo `power_of_two`
|
|
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
val: IntValue<'ctx>,
|
|
power_of_two: IntValue<'ctx>| {
|
|
debug_assert_eq!(
|
|
val.get_type().get_bit_width(),
|
|
power_of_two.get_type().get_bit_width()
|
|
);
|
|
|
|
let llvm_val_t = val.get_type();
|
|
|
|
let max_rem = ctx
|
|
.builder
|
|
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
|
|
.unwrap();
|
|
ctx.builder
|
|
.build_and(
|
|
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
|
ctx.builder.build_not(max_rem, "").unwrap(),
|
|
"",
|
|
)
|
|
.unwrap()
|
|
};
|
|
|
|
// Setup types
|
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
|
|
|
// Allocate the resulting ndarray
|
|
// A condition after format_rpc_ret ensures this will not be popped this off.
|
|
let ndarray = llvm_ret_ty.alloca(generator, ctx, Some("rpc.result"));
|
|
|
|
// Setup ndims
|
|
let ndims =
|
|
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
|
assert_eq!(values.len(), 1);
|
|
|
|
u64::try_from(values[0].clone()).unwrap()
|
|
} else {
|
|
unreachable!();
|
|
};
|
|
// Set `ndarray.ndims`
|
|
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
|
// Allocate `ndarray.shape` [size_t; ndims]
|
|
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
|
|
|
/*
|
|
ndarray now:
|
|
- .ndims: initialized
|
|
- .shape: allocated but uninitialized .shape
|
|
- .data: uninitialized
|
|
*/
|
|
|
|
let llvm_usize_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
|
.unwrap();
|
|
let llvm_pdata_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(
|
|
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
|
llvm_usize,
|
|
"",
|
|
)
|
|
.unwrap();
|
|
let llvm_elem_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
|
.unwrap();
|
|
|
|
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
|
// (4 + 4 * ndims) bytes with 8-byte alignment
|
|
let sizeof_dims =
|
|
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
|
let unaligned_buffer_size =
|
|
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
|
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
|
|
|
let stackptr = call_stacksave(ctx, None);
|
|
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
|
let buffer = ctx
|
|
.builder
|
|
.build_array_alloca(
|
|
llvm_i8_8,
|
|
ctx.builder
|
|
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
|
.unwrap(),
|
|
"rpc.buffer",
|
|
)
|
|
.unwrap();
|
|
let buffer = ctx
|
|
.builder
|
|
.build_bit_cast(buffer, llvm_pi8, "")
|
|
.map(BasicValueEnum::into_pointer_value)
|
|
.unwrap();
|
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
|
|
|
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
|
//
|
|
// The returned value is the number of bytes for `ndarray.data`.
|
|
let ndarray_nbytes = ctx
|
|
.build_call_or_invoke(
|
|
rpc_recv,
|
|
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
|
"rpc.size.next",
|
|
)
|
|
.map(BasicValueEnum::into_int_value)
|
|
.unwrap();
|
|
|
|
// debug_assert(ndarray_nbytes > 0)
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
ctx.make_assert(
|
|
generator,
|
|
ctx.builder
|
|
.build_int_compare(
|
|
IntPredicate::UGT,
|
|
ndarray_nbytes,
|
|
ndarray_nbytes.get_type().const_zero(),
|
|
"",
|
|
)
|
|
.unwrap(),
|
|
"0:AssertionError",
|
|
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
|
[None, None, None],
|
|
ctx.current_loc,
|
|
);
|
|
}
|
|
|
|
// Copy shape from the buffer to `ndarray.shape`.
|
|
let pbuffer_dims =
|
|
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
|
|
|
call_memcpy_generic(
|
|
ctx,
|
|
ndarray.shape().base_ptr(ctx, generator),
|
|
pbuffer_dims,
|
|
sizeof_dims,
|
|
llvm_i1.const_zero(),
|
|
);
|
|
// Restore stack from before allocation of buffer
|
|
call_stackrestore(ctx, stackptr);
|
|
|
|
// Allocate `ndarray.data`.
|
|
// `ndarray.shape` must be initialized beforehand in this implementation
|
|
// (for ndarray.create_data() to know how many elements to allocate)
|
|
let num_elements =
|
|
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
|
|
|
|
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
let sizeof_data =
|
|
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
|
|
|
ctx.make_assert(
|
|
generator,
|
|
ctx.builder.build_int_compare(IntPredicate::UGE,
|
|
sizeof_data,
|
|
ndarray_nbytes,
|
|
"",
|
|
).unwrap(),
|
|
"0:AssertionError",
|
|
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
|
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
|
ctx.current_loc,
|
|
);
|
|
}
|
|
|
|
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
|
|
|
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
|
let ndarray_data_i8 =
|
|
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
|
|
|
// NOTE: Currently on `prehead_bb`
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
|
|
ctx.builder.position_at_end(head_bb);
|
|
|
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
|
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
|
|
|
let alloc_size = ctx
|
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
|
.map(BasicValueEnum::into_int_value)
|
|
.unwrap();
|
|
|
|
let is_done = ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
|
.unwrap();
|
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(alloc_bb);
|
|
// Align the allocation to sizeof(T)
|
|
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
|
let alloc_ptr = ctx
|
|
.builder
|
|
.build_array_alloca(
|
|
llvm_elem_ty,
|
|
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
|
"rpc.alloc",
|
|
)
|
|
.unwrap();
|
|
let alloc_ptr =
|
|
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(tail_bb);
|
|
ndarray.as_base_value().into()
|
|
}
|
|
|
|
_ => {
|
|
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
|
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
ctx.builder.position_at_end(head_bb);
|
|
|
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
|
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
|
let alloc_size = ctx
|
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
|
.unwrap()
|
|
.into_int_value();
|
|
let is_done = ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
|
.unwrap();
|
|
|
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
|
ctx.builder.position_at_end(alloc_bb);
|
|
|
|
let alloc_ptr =
|
|
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
|
let alloc_ptr =
|
|
ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(tail_bb);
|
|
ctx.builder.build_load(slot, "rpc.result").unwrap()
|
|
}
|
|
};
|
|
|
|
Some(result)
|
|
}
|
|
|
|
fn rpc_codegen_callback_fn<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
|
generator: &mut dyn CodeGenerator,
|
|
is_async: bool,
|
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
|
let int8 = ctx.ctx.i8_type();
|
|
let int32 = ctx.ctx.i32_type();
|
|
let size_type = generator.get_size_type(ctx.ctx);
|
|
let ptr_type = int8.ptr_type(AddressSpace::default());
|
|
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
|
|
|
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
|
// -- setup rpc tags
|
|
let mut tag = Vec::new();
|
|
if obj.is_some() {
|
|
tag.push(b'O');
|
|
}
|
|
for arg in &fun.0.args {
|
|
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
|
}
|
|
tag.push(b':');
|
|
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
|
|
|
let mut hasher = DefaultHasher::new();
|
|
tag.hash(&mut hasher);
|
|
let hash = format!("{}", hasher.finish());
|
|
|
|
let tag_ptr = ctx
|
|
.module
|
|
.get_global(hash.as_str())
|
|
.unwrap_or_else(|| {
|
|
let tag_arr_ptr = ctx.module.add_global(
|
|
int8.array_type(tag.len() as u32),
|
|
None,
|
|
format!("tagptr{}", fun.1 .0).as_str(),
|
|
);
|
|
tag_arr_ptr.set_initializer(&int8.const_array(
|
|
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
|
));
|
|
tag_arr_ptr.set_linkage(Linkage::Private);
|
|
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
|
tag_ptr.set_linkage(Linkage::Private);
|
|
tag_ptr.set_initializer(&ctx.ctx.const_struct(
|
|
&[
|
|
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
|
|
size_type.const_int(tag.len() as u64, false).into(),
|
|
],
|
|
false,
|
|
));
|
|
tag_ptr
|
|
})
|
|
.as_pointer_value();
|
|
|
|
let arg_length = args.len() + usize::from(obj.is_some());
|
|
|
|
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
|
let args_ptr = ctx
|
|
.builder
|
|
.build_array_alloca(
|
|
ptr_type,
|
|
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
|
"argptr",
|
|
)
|
|
.unwrap();
|
|
|
|
// -- rpc args handling
|
|
let mut keys = fun.0.args.clone();
|
|
let mut mapping = HashMap::new();
|
|
for (key, value) in args {
|
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
|
}
|
|
// default value handling
|
|
for k in keys {
|
|
mapping
|
|
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
|
}
|
|
// reorder the parameters
|
|
let mut real_params = fun
|
|
.0
|
|
.args
|
|
.iter()
|
|
.map(|arg| {
|
|
mapping
|
|
.remove(&arg.name)
|
|
.unwrap()
|
|
.to_basic_value_enum(ctx, generator, arg.ty)
|
|
.map(|llvm_val| (llvm_val, arg.ty))
|
|
})
|
|
.collect::<Result<Vec<(_, _)>, _>>()?;
|
|
if let Some(obj) = obj {
|
|
if let ValueEnum::Static(obj_val) = obj.1 {
|
|
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
|
} else {
|
|
// should be an error here...
|
|
panic!("only host object is allowed");
|
|
}
|
|
}
|
|
|
|
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
|
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
|
let arg_ptr = unsafe {
|
|
ctx.builder.build_gep(
|
|
args_ptr,
|
|
&[int32.const_int(i as u64, false)],
|
|
&format!("rpc.arg{i}"),
|
|
)
|
|
}
|
|
.unwrap();
|
|
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
|
}
|
|
|
|
// call
|
|
if is_async {
|
|
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
|
ctx.module.add_function(
|
|
"rpc_send_async",
|
|
ctx.ctx.void_type().fn_type(
|
|
&[
|
|
int32.into(),
|
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
],
|
|
false,
|
|
),
|
|
None,
|
|
)
|
|
});
|
|
ctx.builder
|
|
.build_call(
|
|
rpc_send_async,
|
|
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
|
"rpc.send",
|
|
)
|
|
.unwrap();
|
|
} else {
|
|
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
|
ctx.module.add_function(
|
|
"rpc_send",
|
|
ctx.ctx.void_type().fn_type(
|
|
&[
|
|
int32.into(),
|
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
],
|
|
false,
|
|
),
|
|
None,
|
|
)
|
|
});
|
|
ctx.builder
|
|
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
|
.unwrap();
|
|
}
|
|
|
|
// reclaim stack space used by arguments
|
|
call_stackrestore(ctx, stackptr);
|
|
|
|
if is_async {
|
|
// async RPCs do not return any values
|
|
Ok(None)
|
|
} else {
|
|
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
|
|
|
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
|
// An RPC returning an NDArray would not touch here.
|
|
call_stackrestore(ctx, stackptr);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
pub fn attributes_writeback<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
inner_resolver: &InnerResolver,
|
|
host_attributes: &PyObject,
|
|
return_obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
) -> Result<(), String> {
|
|
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
|
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
|
let top_levels = ctx.top_level.definitions.read();
|
|
let globals = inner_resolver.global_value_ids.read();
|
|
let int32 = ctx.ctx.i32_type();
|
|
let zero = int32.const_zero();
|
|
let mut values = Vec::new();
|
|
let mut scratch_buffer = Vec::new();
|
|
|
|
if let Some((ty, obj)) = return_obj {
|
|
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
|
|
}
|
|
|
|
for val in (*globals).values() {
|
|
let val = val.as_ref(py);
|
|
let ty = inner_resolver.get_obj_type(
|
|
py,
|
|
val,
|
|
&mut ctx.unifier,
|
|
&top_levels,
|
|
&ctx.primitives,
|
|
)?;
|
|
if let Err(ty) = ty {
|
|
return Ok(Err(ty));
|
|
}
|
|
let ty = ty.unwrap();
|
|
match &*ctx.unifier.get_ty(ty) {
|
|
TypeEnum::TObj { fields, obj_id, .. }
|
|
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
|
{
|
|
// we only care about primitive attributes
|
|
// for non-primitive attributes, they should be in another global
|
|
let mut attributes = Vec::new();
|
|
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
|
|
for (name, (field_ty, is_mutable)) in fields {
|
|
if !is_mutable {
|
|
continue;
|
|
}
|
|
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
|
attributes.push(name.to_string());
|
|
let (index, _) = ctx.get_attr_index(ty, *name);
|
|
values.push((
|
|
*field_ty,
|
|
ctx.build_gep_and_load(
|
|
obj.into_pointer_value(),
|
|
&[zero, int32.const_int(index as u64, false)],
|
|
None,
|
|
),
|
|
));
|
|
}
|
|
}
|
|
if !attributes.is_empty() {
|
|
let pydict = PyDict::new(py);
|
|
pydict.set_item("obj", val)?;
|
|
pydict.set_item("fields", attributes)?;
|
|
host_attributes.append(pydict)?;
|
|
}
|
|
}
|
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let elem_ty = iter_type_vars(params).next().unwrap().ty;
|
|
|
|
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
|
|
let pydict = PyDict::new(py);
|
|
pydict.set_item("obj", val)?;
|
|
host_attributes.append(pydict)?;
|
|
values.push((
|
|
ty,
|
|
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
|
|
));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
let fun = FunSignature {
|
|
args: values
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, (ty, _))| FuncArg {
|
|
name: i.to_string().into(),
|
|
ty: *ty,
|
|
default_value: None,
|
|
is_vararg: false,
|
|
})
|
|
.collect(),
|
|
ret: ctx.primitives.none,
|
|
vars: VarMap::default(),
|
|
};
|
|
let args: Vec<_> =
|
|
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
|
if let Err(e) =
|
|
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, true)
|
|
{
|
|
return Ok(Err(e));
|
|
}
|
|
Ok(Ok(()))
|
|
})
|
|
.unwrap()?;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
|
|
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
|
rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
|
|
})))
|
|
}
|
|
|
|
/// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with
|
|
/// [`llvm_usize`] as its native word size.
|
|
///
|
|
/// Note that, similar to format constants in `<inttypes.h>`, these constants need to be prepended
|
|
/// with `%`.
|
|
#[must_use]
|
|
fn get_fprintf_format_constant<'ctx>(
|
|
llvm_usize: IntType<'ctx>,
|
|
llvm_int_t: IntType<'ctx>,
|
|
is_unsigned: bool,
|
|
) -> String {
|
|
debug_assert!(matches!(llvm_usize.get_bit_width(), 8 | 16 | 32 | 64));
|
|
|
|
let conv_spec = if is_unsigned { 'u' } else { 'd' };
|
|
|
|
// https://en.cppreference.com/w/c/language/arithmetic_types
|
|
// Note that NAC3 does **not** support LP32 and LLP64 configurations
|
|
match llvm_int_t.get_bit_width() {
|
|
8 => format!("hh{conv_spec}"),
|
|
16 => format!("h{conv_spec}"),
|
|
32 => conv_spec.to_string(),
|
|
64 => format!("{}{conv_spec}", if llvm_usize.get_bit_width() == 64 { "l" } else { "ll" }),
|
|
_ => todo!(
|
|
"Not yet implemented for i{} on {}-bit platform",
|
|
llvm_int_t.get_bit_width(),
|
|
llvm_usize.get_bit_width()
|
|
),
|
|
}
|
|
}
|
|
|
|
/// Prints one or more `values` to `core_log` or `rtio_log`.
|
|
///
|
|
/// * `separator` - The separator between multiple values.
|
|
/// * `suffix` - String to terminate the printed string, if any.
|
|
/// * `as_repr` - Whether the `repr()` output of values instead of `str()`.
|
|
/// * `as_rtio` - Whether to print to `rtio_log` instead of `core_log`.
|
|
fn polymorphic_print<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
values: &[(Type, ValueEnum<'ctx>)],
|
|
separator: &str,
|
|
suffix: Option<&str>,
|
|
as_repr: bool,
|
|
as_rtio: bool,
|
|
) -> Result<(), String> {
|
|
let printf = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
fmt: String,
|
|
args: Vec<BasicValueEnum<'ctx>>| {
|
|
debug_assert!(!fmt.is_empty());
|
|
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
|
|
|
|
let fn_name = if as_rtio { "rtio_log" } else { "core_log" };
|
|
let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
|
|
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
|
let fn_t = if as_rtio {
|
|
let llvm_void = ctx.ctx.void_type();
|
|
llvm_void.fn_type(&[llvm_pi8.into()], true)
|
|
} else {
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
llvm_i32.fn_type(&[llvm_pi8.into()], true)
|
|
};
|
|
ctx.module.add_function(fn_name, fn_t, None)
|
|
});
|
|
|
|
let fmt = ctx.gen_string(generator, fmt);
|
|
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
|
|
ctx.builder
|
|
.build_call(
|
|
print_fn,
|
|
&once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(),
|
|
"",
|
|
)
|
|
.unwrap();
|
|
};
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
let llvm_i64 = ctx.ctx.i64_type();
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
let suffix = suffix.unwrap_or_default();
|
|
|
|
let mut fmt = String::new();
|
|
let mut args = Vec::new();
|
|
|
|
let flush = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
fmt: &mut String,
|
|
args: &mut Vec<BasicValueEnum<'ctx>>| {
|
|
if !fmt.is_empty() {
|
|
fmt.push('\0');
|
|
printf(ctx, generator, mem::take(fmt), mem::take(args));
|
|
}
|
|
};
|
|
|
|
for (ty, value) in values {
|
|
let ty = *ty;
|
|
let value = value.clone().to_basic_value_enum(ctx, generator, ty).unwrap();
|
|
|
|
if !fmt.is_empty() {
|
|
fmt.push_str(separator);
|
|
}
|
|
|
|
match &*ctx.unifier.get_ty_immutable(ty) {
|
|
TypeEnum::TTuple { ty: tys, is_vararg_ctx: false } => {
|
|
let pvalue = {
|
|
let pvalue = generator.gen_var_alloc(ctx, value.get_type(), None).unwrap();
|
|
ctx.builder.build_store(pvalue, value).unwrap();
|
|
pvalue
|
|
};
|
|
|
|
fmt.push('(');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let tuple_vals = tys
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, ty)| {
|
|
(*ty, {
|
|
let pfield =
|
|
ctx.builder.build_struct_gep(pvalue, i as u32, "").unwrap();
|
|
|
|
ValueEnum::from(ctx.builder.build_load(pfield, "").unwrap())
|
|
})
|
|
})
|
|
.collect_vec();
|
|
|
|
polymorphic_print(ctx, generator, &tuple_vals, ", ", None, true, as_rtio)?;
|
|
|
|
if tuple_vals.len() == 1 {
|
|
fmt.push_str(",)");
|
|
} else {
|
|
fmt.push(')');
|
|
}
|
|
}
|
|
|
|
TypeEnum::TFunc { .. } => todo!(),
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::None.id() => {
|
|
fmt.push_str("None");
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Bool.id() => {
|
|
fmt.push_str("%.*s");
|
|
|
|
let true_str = ctx.gen_string(generator, "True");
|
|
let true_data =
|
|
unsafe { true_str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let true_len = unsafe { true_str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
let false_str = ctx.gen_string(generator, "False");
|
|
let false_data =
|
|
unsafe { false_str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let false_len =
|
|
unsafe { false_str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
|
|
let bool_val = generator.bool_to_i1(ctx, value.into_int_value());
|
|
|
|
args.extend([
|
|
ctx.builder.build_select(bool_val, true_len, false_len, "").unwrap(),
|
|
ctx.builder.build_select(bool_val, true_data, false_data, "").unwrap(),
|
|
]);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. }
|
|
if *obj_id == PrimDef::Int32.id()
|
|
|| *obj_id == PrimDef::Int64.id()
|
|
|| *obj_id == PrimDef::UInt32.id()
|
|
|| *obj_id == PrimDef::UInt64.id() =>
|
|
{
|
|
let is_unsigned =
|
|
*obj_id == PrimDef::UInt32.id() || *obj_id == PrimDef::UInt64.id();
|
|
|
|
let llvm_int_t = value.get_type().into_int_type();
|
|
debug_assert!(matches!(llvm_usize.get_bit_width(), 32 | 64));
|
|
debug_assert!(matches!(llvm_int_t.get_bit_width(), 32 | 64));
|
|
|
|
let fmt_spec = format!(
|
|
"%{}",
|
|
get_fprintf_format_constant(llvm_usize, llvm_int_t, is_unsigned)
|
|
);
|
|
|
|
fmt.push_str(fmt_spec.as_str());
|
|
args.push(value);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Float.id() => {
|
|
fmt.push_str("%g");
|
|
args.push(value);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Str.id() => {
|
|
if as_repr {
|
|
fmt.push_str("\"%.*s\"");
|
|
} else {
|
|
fmt.push_str("%.*s");
|
|
}
|
|
|
|
let str = value.into_struct_value();
|
|
let str_data = unsafe { str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let str_len = unsafe { str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
|
|
args.extend(&[str_len.into(), str_data.into()]);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let elem_ty = *params.iter().next().unwrap().1;
|
|
|
|
fmt.push('[');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val =
|
|
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
|
|
let len = val.load_size(ctx, None);
|
|
let last =
|
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
|
|
|
gen_for_callback_incrementing(
|
|
generator,
|
|
ctx,
|
|
None,
|
|
llvm_usize.const_zero(),
|
|
(len, false),
|
|
|generator, ctx, _, i| {
|
|
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(elem_ty, elem.into())],
|
|
"",
|
|
None,
|
|
true,
|
|
as_rtio,
|
|
)?;
|
|
|
|
gen_if_callback(
|
|
generator,
|
|
ctx,
|
|
|_, ctx| {
|
|
Ok(ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::ULT, i, last, "")
|
|
.unwrap())
|
|
},
|
|
|generator, ctx| {
|
|
printf(ctx, generator, ", \0".into(), Vec::default());
|
|
|
|
Ok(())
|
|
},
|
|
|_, _| Ok(()),
|
|
)?;
|
|
|
|
Ok(())
|
|
},
|
|
llvm_usize.const_int(1, false),
|
|
)?;
|
|
|
|
fmt.push(']');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
fmt.push_str("array([");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val = NDArrayValue::from_pointer_value(
|
|
value.into_pointer_value(),
|
|
llvm_elem_ty,
|
|
llvm_usize,
|
|
None,
|
|
);
|
|
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
|
|
let last =
|
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
|
|
|
gen_for_callback_incrementing(
|
|
generator,
|
|
ctx,
|
|
None,
|
|
llvm_usize.const_zero(),
|
|
(len, false),
|
|
|generator, ctx, _, i| {
|
|
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(elem_ty, elem.into())],
|
|
"",
|
|
None,
|
|
true,
|
|
as_rtio,
|
|
)?;
|
|
|
|
gen_if_callback(
|
|
generator,
|
|
ctx,
|
|
|_, ctx| {
|
|
Ok(ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::ULT, i, last, "")
|
|
.unwrap())
|
|
},
|
|
|generator, ctx| {
|
|
printf(ctx, generator, ", \0".into(), Vec::default());
|
|
|
|
Ok(())
|
|
},
|
|
|_, _| Ok(()),
|
|
)?;
|
|
|
|
Ok(())
|
|
},
|
|
llvm_usize.const_int(1, false),
|
|
)?;
|
|
|
|
fmt.push_str(")]");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Range.id() => {
|
|
fmt.push_str("range(");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val = RangeValue::from_pointer_value(value.into_pointer_value(), None);
|
|
|
|
let (start, stop, step) = destructure_range(ctx, val);
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[
|
|
(ctx.primitives.int32, start.into()),
|
|
(ctx.primitives.int32, stop.into()),
|
|
(ctx.primitives.int32, step.into()),
|
|
],
|
|
", ",
|
|
None,
|
|
false,
|
|
as_rtio,
|
|
)?;
|
|
|
|
fmt.push(')');
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Exception.id() => {
|
|
let fmt_str = format!(
|
|
"%{}(%{}, %{1:}, %{1:})",
|
|
get_fprintf_format_constant(llvm_usize, llvm_i32, false),
|
|
get_fprintf_format_constant(llvm_usize, llvm_i64, false),
|
|
);
|
|
|
|
let exn = value.into_pointer_value();
|
|
let name = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param0 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(6, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param1 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(7, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param2 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(8, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
|
|
fmt.push_str(fmt_str.as_str());
|
|
args.extend_from_slice(&[name.into(), param0.into(), param1.into(), param2.into()]);
|
|
}
|
|
|
|
_ => unreachable!(
|
|
"Unsupported object type for polymorphic_print: {}",
|
|
ctx.unifier.stringify(ty)
|
|
),
|
|
}
|
|
}
|
|
|
|
fmt.push_str(suffix);
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Invokes the `core_log` intrinsic function.
|
|
pub fn call_core_log_impl<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
arg: (Type, BasicValueEnum<'ctx>),
|
|
) -> Result<(), String> {
|
|
let (arg_ty, arg_val) = arg;
|
|
|
|
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\n"), false, false)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Invokes the `rtio_log` intrinsic function.
|
|
pub fn call_rtio_log_impl<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
channel: StructValue<'ctx>,
|
|
arg: (Type, BasicValueEnum<'ctx>),
|
|
) -> Result<(), String> {
|
|
let (arg_ty, arg_val) = arg;
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(ctx.primitives.str, channel.into())],
|
|
" ",
|
|
Some("\x1E"),
|
|
false,
|
|
true,
|
|
)?;
|
|
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\x1D"), false, true)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Generates a call to `core_log`.
|
|
pub fn gen_core_log<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
generator: &mut dyn CodeGenerator,
|
|
) -> Result<(), String> {
|
|
assert!(obj.is_none());
|
|
assert_eq!(args.len(), 1);
|
|
|
|
let value_ty = fun.0.args[0].ty;
|
|
let value_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
|
|
|
|
call_core_log_impl(ctx, generator, (value_ty, value_arg))
|
|
}
|
|
|
|
/// Generates a call to `rtio_log`.
|
|
pub fn gen_rtio_log<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
generator: &mut dyn CodeGenerator,
|
|
) -> Result<(), String> {
|
|
assert!(obj.is_none());
|
|
assert_eq!(args.len(), 2);
|
|
|
|
let channel_ty = fun.0.args[0].ty;
|
|
assert!(ctx.unifier.unioned(channel_ty, ctx.primitives.str));
|
|
let channel_arg =
|
|
args[0].1.clone().to_basic_value_enum(ctx, generator, channel_ty)?.into_struct_value();
|
|
let value_ty = fun.0.args[1].ty;
|
|
let value_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
|
|
|
|
call_rtio_log_impl(ctx, generator, channel_arg, (value_ty, value_arg))
|
|
}
|