forked from M-Labs/nac3
David Mak
08a7d01a13
Temporarily disable linalg ndarray tests as they are not ported to work with strided-ndarray.
1565 lines
59 KiB
Rust
1565 lines
59 KiB
Rust
use std::{
|
|
collections::{hash_map::DefaultHasher, HashMap},
|
|
hash::{Hash, Hasher},
|
|
iter::once,
|
|
mem,
|
|
sync::Arc,
|
|
};
|
|
|
|
use itertools::Itertools;
|
|
use pyo3::{
|
|
types::{PyDict, PyList},
|
|
PyObject, PyResult, Python,
|
|
};
|
|
|
|
use nac3core::{
|
|
codegen::{
|
|
expr::{destructure_range, gen_call},
|
|
irrt::ndarray::call_ndarray_calc_size,
|
|
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
|
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
|
type_aligned_alloca,
|
|
types::ndarray::NDArrayType,
|
|
values::{
|
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue,
|
|
UntypedArrayLikeAccessor,
|
|
},
|
|
CodeGenContext, CodeGenerator,
|
|
},
|
|
inkwell::{
|
|
context::Context,
|
|
module::Linkage,
|
|
types::{BasicType, IntType},
|
|
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
|
|
AddressSpace, IntPredicate, OptimizationLevel,
|
|
},
|
|
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
|
|
symbol_resolver::ValueEnum,
|
|
toplevel::{
|
|
helper::{extract_ndims, PrimDef},
|
|
numpy::unpack_ndarray_var_tys,
|
|
DefinitionId, GenCall,
|
|
},
|
|
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
|
};
|
|
|
|
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
|
|
|
/// Describes how the body of a parallel `with` block is split into parallel branches.
///
/// The mode of the innermost `with parallel` / `with legacy_parallel` statement decides when the
/// generator records the end of the timeline and rewinds it back to the block's start.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ParallelMode {
    /// Code is not currently being generated inside any parallel block.
    None,

    /// Shallow parallelism, selected by `with legacy_parallel`; the default before NAC3.
    ///
    /// Every top-level statement of the `with` body is a separate parallel branch.
    Legacy,

    /// Deep parallelism, selected by `with parallel`; the default since NAC3.
    ///
    /// Every function call in the `with` body, except calls nested in a `sequential` block, is a
    /// separate parallel branch.
    Deep,
}
|
|
|
|
pub struct ArtiqCodeGenerator<'a> {
|
|
name: String,
|
|
|
|
/// The size of a `size_t` variable in bits.
|
|
size_t: u32,
|
|
|
|
/// Monotonic counter for naming `start`/`stop` variables used by `with parallel` blocks.
|
|
name_counter: u32,
|
|
|
|
/// Variable for tracking the start of a `with parallel` block.
|
|
start: Option<Expr<Option<Type>>>,
|
|
|
|
/// Variable for tracking the end of a `with parallel` block.
|
|
end: Option<Expr<Option<Type>>>,
|
|
timeline: &'a (dyn TimeFns + Sync),
|
|
|
|
/// The [`ParallelMode`] of the current parallel context.
|
|
///
|
|
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
|
|
/// statement, which is used to determine when and how the timeline should be updated.
|
|
parallel_mode: ParallelMode,
|
|
}
|
|
|
|
impl<'a> ArtiqCodeGenerator<'a> {
|
|
pub fn new(
|
|
name: String,
|
|
size_t: u32,
|
|
timeline: &'a (dyn TimeFns + Sync),
|
|
) -> ArtiqCodeGenerator<'a> {
|
|
assert!(size_t == 32 || size_t == 64);
|
|
ArtiqCodeGenerator {
|
|
name,
|
|
size_t,
|
|
name_counter: 0,
|
|
start: None,
|
|
end: None,
|
|
timeline,
|
|
parallel_mode: ParallelMode::None,
|
|
}
|
|
}
|
|
|
|
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
|
|
/// position of the timeline to the initial timeline position before entering the `parallel`
|
|
/// block.
|
|
///
|
|
/// Direct-`parallel` block context refers to when the generator is generating statements whose
|
|
/// closest parent `with` statement is a `with parallel` block.
|
|
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
|
|
if let Some(start) = self.start.clone() {
|
|
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
start.custom.unwrap(),
|
|
)?;
|
|
self.timeline.emit_at_mu(ctx, start_val);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// If the generator is currently in a `parallel` block context, emits IR that updates the
|
|
/// maximum end position of the `parallel` block as specified by the timeline `end` value.
|
|
///
|
|
/// In general the `end` parameter should be set to `self.end` for updating the maximum end
|
|
/// position for the current `parallel` block. Other values can be passed in to update the
|
|
/// maximum end position for other `parallel` blocks.
|
|
///
|
|
/// `parallel`-block context refers to when the generator is generating statements within a
|
|
/// (possibly indirect) `parallel` block.
|
|
///
|
|
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
|
|
/// the end of the provided value name.
|
|
fn timeline_update_end_max(
|
|
&mut self,
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
|
end: Option<Expr<Option<Type>>>,
|
|
store_name: Option<&str>,
|
|
) -> Result<(), String> {
|
|
if let Some(end) = end {
|
|
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
|
|
ctx,
|
|
self,
|
|
end.custom.unwrap(),
|
|
)?;
|
|
let now = self.timeline.emit_now_mu(ctx);
|
|
let max =
|
|
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
|
|
let end_store = self
|
|
.gen_store_target(
|
|
ctx,
|
|
&end,
|
|
store_name.map(|name| format!("{name}.addr")).as_deref(),
|
|
)?
|
|
.unwrap();
|
|
ctx.builder.build_store(end_store, max).unwrap();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// [`CodeGenerator`] implementation for ARTIQ kernels.
///
/// Uses the default `nac3core` code generation, adding timeline bookkeeping for
/// `with parallel`, `with legacy_parallel` and `with sequential` blocks.
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
    /// Returns the name given to this generator on construction.
    fn get_name(&self) -> &str {
        &self.name
    }

    /// Returns `i32` when `size_t` is 32 bits wide, and `i64` otherwise.
    fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
        if self.size_t == 32 {
            ctx.i32_type()
        } else {
            ctx.i64_type()
        }
    }

    /// Generates a block of statements.
    ///
    /// In a `legacy_parallel` context, every top-level statement is a parallel branch: after each
    /// statement the block's end position is updated and the timeline is rewound to the block's
    /// start. Otherwise this defers to `nac3core`'s `gen_block`.
    fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
        &mut self,
        ctx: &mut CodeGenContext<'ctx, 'a>,
        stmts: I,
    ) -> Result<(), String>
    where
        Self: Sized,
    {
        // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
        // in the parallel block
        if self.parallel_mode == ParallelMode::Legacy {
            for stmt in stmts {
                self.gen_stmt(ctx, stmt)?;

                // Nothing may be emitted after a terminator (e.g. `return`) in this basic block.
                if ctx.is_terminated() {
                    break;
                }

                self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
                self.timeline_reset_start(ctx)?;
            }

            Ok(())
        } else {
            gen_block(self, ctx, stmts)
        }
    }

    /// Generates a function call through `nac3core`'s `gen_call`.
    ///
    /// In a deep `parallel` context, every call is a parallel branch: afterwards the block's end
    /// position is updated and the timeline is rewound to the block's start (the rewind is a
    /// no-op inside a nested `sequential` block, where `self.start` is temporarily cleared).
    fn gen_call<'ctx>(
        &mut self,
        ctx: &mut CodeGenContext<'ctx, '_>,
        obj: Option<(Type, ValueEnum<'ctx>)>,
        fun: (&FunSignature, DefinitionId),
        params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
    ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
        let result = gen_call(self, ctx, obj, fun, params)?;

        // Deep parallel emits timeline end-update/timeline-reset after each function call
        if self.parallel_mode == ParallelMode::Deep {
            self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
            self.timeline_reset_start(ctx)?;
        }

        Ok(result)
    }

    /// Generates a `with` statement.
    ///
    /// `with parallel`, `with legacy_parallel` and `with sequential` (with a single context item
    /// and no `as` target) are handled here; every other `with` statement is delegated to
    /// `nac3core`'s `gen_with`.
    fn gen_with(
        &mut self,
        ctx: &mut CodeGenContext<'_, '_>,
        stmt: &Stmt<Option<Type>>,
    ) -> Result<(), String> {
        let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };

        if items.len() == 1 && items[0].optional_vars.is_none() {
            let item = &items[0];

            // Behavior of parallel and sequential:
            // Each function call (indirectly, can be inside a sequential block) within a parallel
            // block will update the end variable to the maximum now_mu in the block.
            // Each function call directly inside a parallel block will reset the timeline after
            // execution. A parallel block within a sequential block (or not within any block) will
            // set the timeline to the max now_mu within the block (and the outer max now_mu will also
            // be updated).
            //
            // Implementation: We track the start and end separately.
            // - If there is a start variable, it indicates that we are directly inside a
            // parallel block and we have to reset the timeline after every function call.
            // - If there is a end variable, it indicates that we are (indirectly) inside a
            // parallel block, and we should update the max end value.
            if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
                if id == &"parallel".into() || id == &"legacy_parallel".into() {
                    // Save the enclosing context so it can be restored after the body.
                    let old_start = self.start.take();
                    let old_end = self.end.take();
                    let old_parallel_mode = self.parallel_mode;

                    // Initial position of this block: the enclosing block's start when directly
                    // nested in a parallel block, otherwise the current timeline position.
                    let now = if let Some(old_start) = &old_start {
                        self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
                            ctx,
                            self,
                            old_start.custom.unwrap(),
                        )?
                    } else {
                        self.timeline.emit_now_mu(ctx)
                    };

                    // Emulate variable allocation, as we need to use the CodeGenContext
                    // HashMap to store our variable due to lifetime limitation
                    // Note: we should be able to store variables directly if generic
                    // associative type is used by limiting the lifetime of CodeGenerator to
                    // the LLVM Context.
                    // The name is guaranteed to be unique as users cannot use this as variable
                    // name.
                    //
                    // A new `start` variable is only created when there is no enclosing start;
                    // otherwise the enclosing block's start variable is reused.
                    self.start = old_start.clone().map_or_else(
                        || {
                            let start = format!("with-{}-start", self.name_counter).into();
                            let start_expr = Located {
                                // location does not matter at this point
                                location: stmt.location,
                                node: ExprKind::Name { id: start, ctx: *name_ctx },
                                custom: Some(ctx.primitives.int64),
                            };
                            let start = self
                                .gen_store_target(ctx, &start_expr, Some("start.addr"))?
                                .unwrap();
                            ctx.builder.build_store(start, now).unwrap();
                            Ok(Some(start_expr)) as Result<_, String>
                        },
                        |v| Ok(Some(v)),
                    )?;
                    // Every parallel block gets its own `end` variable, initialized to `now`.
                    let end = format!("with-{}-end", self.name_counter).into();
                    let end_expr = Located {
                        // location does not matter at this point
                        location: stmt.location,
                        node: ExprKind::Name { id: end, ctx: *name_ctx },
                        custom: Some(ctx.primitives.int64),
                    };
                    let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
                    ctx.builder.build_store(end, now).unwrap();
                    self.end = Some(end_expr);
                    self.name_counter += 1;
                    self.parallel_mode = match id.to_string().as_str() {
                        "parallel" => ParallelMode::Deep,
                        "legacy_parallel" => ParallelMode::Legacy,
                        _ => unreachable!(),
                    };

                    self.gen_block(ctx, body.iter())?;

                    let current = ctx.builder.get_insert_block().unwrap();

                    // if the current block is terminated, move before the terminator
                    // we want to set the timeline before reaching the terminator
                    // TODO: This may be unsound if there are multiple exit paths in the
                    // block... e.g.
                    // if ...:
                    //     return
                    // Perhaps we can fix this by using actual with block?
                    let reset_position = if let Some(terminator) = current.get_terminator() {
                        ctx.builder.position_before(&terminator);
                        true
                    } else {
                        false
                    };

                    // set duration
                    let end_expr = self.end.take().unwrap();
                    let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
                        ctx,
                        self,
                        end_expr.custom.unwrap(),
                    )?;

                    // inside a sequential block
                    if old_start.is_none() {
                        self.timeline.emit_at_mu(ctx, end_val);
                    }

                    // inside a parallel block, should update the outer max now_mu
                    self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;

                    // Restore the enclosing parallel context.
                    self.parallel_mode = old_parallel_mode;
                    self.end = old_end;
                    self.start = old_start;

                    if reset_position {
                        ctx.builder.position_at_end(current);
                    }

                    return Ok(());
                } else if id == &"sequential".into() {
                    // For deep parallel, temporarily take away start to avoid function calls in
                    // the block from resetting the timeline.
                    // This does not affect legacy parallel, as the timeline will be reset after
                    // this block finishes execution.
                    let start = self.start.take();
                    self.gen_block(ctx, body.iter())?;
                    self.start = start;

                    // Reset the timeline when we are exiting the sequential block
                    // Legacy parallel does not need this, since it will be reset after codegen
                    // for this statement is completed
                    if self.parallel_mode == ParallelMode::Deep {
                        self.timeline_reset_start(ctx)?;
                    }

                    return Ok(());
                }
            }
        }

        // not parallel/sequential
        gen_with(self, ctx, stmt)
    }
}
|
|
|
|
fn gen_rpc_tag(
|
|
ctx: &mut CodeGenContext<'_, '_>,
|
|
ty: Type,
|
|
buffer: &mut Vec<u8>,
|
|
) -> Result<(), String> {
|
|
use nac3core::typecheck::typedef::TypeEnum::*;
|
|
|
|
let int32 = ctx.primitives.int32;
|
|
let int64 = ctx.primitives.int64;
|
|
let float = ctx.primitives.float;
|
|
let bool = ctx.primitives.bool;
|
|
let str = ctx.primitives.str;
|
|
let none = ctx.primitives.none;
|
|
|
|
if ctx.unifier.unioned(ty, int32) {
|
|
buffer.push(b'i');
|
|
} else if ctx.unifier.unioned(ty, int64) {
|
|
buffer.push(b'I');
|
|
} else if ctx.unifier.unioned(ty, float) {
|
|
buffer.push(b'f');
|
|
} else if ctx.unifier.unioned(ty, bool) {
|
|
buffer.push(b'b');
|
|
} else if ctx.unifier.unioned(ty, str) {
|
|
buffer.push(b's');
|
|
} else if ctx.unifier.unioned(ty, none) {
|
|
buffer.push(b'n');
|
|
} else {
|
|
let ty_enum = ctx.unifier.get_ty(ty);
|
|
match &*ty_enum {
|
|
TTuple { ty, is_vararg_ctx: false } => {
|
|
buffer.push(b't');
|
|
buffer.push(ty.len() as u8);
|
|
for ty in ty {
|
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
|
}
|
|
}
|
|
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let ty = iter_type_vars(params).next().unwrap().ty;
|
|
|
|
buffer.push(b'l');
|
|
gen_rpc_tag(ctx, ty, buffer)?;
|
|
}
|
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
let ndarray_ndims = if let TLiteral { values, .. } =
|
|
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
|
{
|
|
if values.len() != 1 {
|
|
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
|
}
|
|
|
|
let value = values[0].clone();
|
|
u64::try_from(value.clone())
|
|
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
|
} else {
|
|
unreachable!()
|
|
};
|
|
assert!(
|
|
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
|
|
"Only NDArrays of sizes between 0 and 255 can be RPCed"
|
|
);
|
|
|
|
buffer.push(b'a');
|
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
|
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
|
}
|
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Formats an RPC argument to conform to the expected format required by `send_value`.
///
/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format.
///
/// Returns an `i8*` pointing at stack storage holding the formatted argument:
///
/// - for NDArrays, a buffer of `sizeof(T*) + ndims * sizeof(size_t)` bytes holding the data
///   pointer followed by the shape;
/// - for every other type, a stack slot named `rpc.arg{arg_idx}` holding a copy of `arg`.
///
/// The storage is allocated on the stack of the current function; the caller is expected to
/// keep it alive until the RPC has been sent.
fn format_rpc_arg<'ctx>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, '_>,
    (arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
) -> PointerValue<'ctx> {
    let llvm_i8 = ctx.ctx.i8_type();
    let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());

    let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) {
        TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
            // NAC3: NDArray = { usize, usize*, T* }
            // libproto_artiq: NDArray = [data[..], dim_sz[..]]
            // NOTE(review): the NAC3 layout above may predate the strided-ndarray struct; the
            // code below only relies on `ptr_to_data`, `shape` and `load_ndims` — confirm.

            let llvm_i1 = ctx.ctx.bool_type();
            let llvm_usize = generator.get_size_type(ctx.ctx);

            let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
            let ndims = extract_ndims(&ctx.unifier, ndims);
            let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
            let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, Some(ndims));
            let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None);

            // sizeof(size_t) of the ndarray's size type, expressed as `llvm_usize`.
            let llvm_usize_sizeof = ctx
                .builder
                .build_int_truncate_or_bit_cast(
                    llvm_arg.get_type().size_type().size_of(),
                    llvm_usize,
                    "",
                )
                .unwrap();
            // sizeof(T*), expressed as `llvm_usize`.
            let llvm_pdata_sizeof = ctx
                .builder
                .build_int_truncate_or_bit_cast(
                    llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
                    llvm_usize,
                    "",
                )
                .unwrap();

            // Size of the shape: ndims * sizeof(size_t).
            let dims_buf_sz =
                ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();

            // Total buffer size: sizeof(T*) + ndims * sizeof(size_t).
            let buffer_size =
                ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();

            let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
            let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));

            // Copy `sizeof(T*)` bytes from `ptr_to_data` (presumably the address of the
            // ndarray's data pointer field) into the start of the buffer.
            call_memcpy_generic(
                ctx,
                buffer.base_ptr(ctx, generator),
                llvm_arg.ptr_to_data(ctx),
                llvm_pdata_sizeof,
                llvm_i1.const_zero(),
            );

            // SAFETY: `buffer` is `sizeof(T*) + ndims * sizeof(size_t)` bytes long, so offsetting
            // by `sizeof(T*)` stays within it.
            let pbuffer_dims_begin =
                unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
            // Copy the shape right after the data pointer.
            call_memcpy_generic(
                ctx,
                pbuffer_dims_begin,
                llvm_arg.shape().base_ptr(ctx, generator),
                dims_buf_sz,
                llvm_i1.const_zero(),
            );

            buffer.base_ptr(ctx, generator)
        }

        _ => {
            // Non-ndarray values are passed by reference to a stack copy of the value.
            let arg_slot = generator
                .gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))
                .unwrap();
            ctx.builder.build_store(arg_slot, arg).unwrap();

            ctx.builder
                .build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
                .map(BasicValueEnum::into_pointer_value)
                .unwrap()
        }
    };

    debug_assert_eq!(arg_slot.get_type(), llvm_pi8);

    arg_slot
}
|
|
|
|
/// Formats an RPC return value to conform to the expected format required by NAC3.
|
|
fn format_rpc_ret<'ctx>(
|
|
generator: &mut dyn CodeGenerator,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
ret_ty: Type,
|
|
) -> Option<BasicValueEnum<'ctx>> {
|
|
// -- receive value:
|
|
// T result = {
|
|
// void *ret_ptr = alloca(sizeof(T));
|
|
// void *ptr = ret_ptr;
|
|
// loop: int size = rpc_recv(ptr);
|
|
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
|
// if(size) { ptr = alloca(size); goto loop; }
|
|
// else *(T*)ret_ptr
|
|
// }
|
|
|
|
let llvm_i8 = ctx.ctx.i8_type();
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
|
|
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
|
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
|
});
|
|
|
|
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
|
|
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
|
|
return None;
|
|
}
|
|
|
|
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
|
let current_function = prehead_bb.get_parent().unwrap();
|
|
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
|
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
|
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
|
|
|
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
|
|
|
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
// Round `val` up to its modulo `power_of_two`
|
|
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
val: IntValue<'ctx>,
|
|
power_of_two: IntValue<'ctx>| {
|
|
debug_assert_eq!(
|
|
val.get_type().get_bit_width(),
|
|
power_of_two.get_type().get_bit_width()
|
|
);
|
|
|
|
let llvm_val_t = val.get_type();
|
|
|
|
let max_rem = ctx
|
|
.builder
|
|
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
|
|
.unwrap();
|
|
ctx.builder
|
|
.build_and(
|
|
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
|
ctx.builder.build_not(max_rem, "").unwrap(),
|
|
"",
|
|
)
|
|
.unwrap()
|
|
};
|
|
|
|
// Setup types
|
|
let llvm_ret_ty = NDArrayType::from_unifier_type(generator, ctx, ret_ty);
|
|
let llvm_elem_ty = llvm_ret_ty.element_type();
|
|
|
|
// Allocate the resulting ndarray
|
|
// A condition after format_rpc_ret ensures this will not be popped this off.
|
|
let ndarray = llvm_ret_ty.alloca(generator, ctx, Some("rpc.result"));
|
|
|
|
// Setup ndims
|
|
let ndims = llvm_ret_ty.ndims().unwrap();
|
|
// Set `ndarray.ndims`
|
|
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
|
// Allocate `ndarray.shape` [size_t; ndims]
|
|
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
|
|
|
/*
|
|
ndarray now:
|
|
- .ndims: initialized
|
|
- .shape: allocated but uninitialized .shape
|
|
- .data: uninitialized
|
|
*/
|
|
|
|
let llvm_usize_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
|
.unwrap();
|
|
let llvm_pdata_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(
|
|
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
|
llvm_usize,
|
|
"",
|
|
)
|
|
.unwrap();
|
|
let llvm_elem_sizeof = ctx
|
|
.builder
|
|
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
|
.unwrap();
|
|
|
|
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
|
// (4 + 4 * ndims) bytes with 8-byte alignment
|
|
let sizeof_dims =
|
|
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
|
let buffer_size =
|
|
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
|
|
|
let stackptr = call_stacksave(ctx, None);
|
|
let buffer =
|
|
type_aligned_alloca(generator, ctx, llvm_i8_8, buffer_size, Some("rpc.buffer"));
|
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
|
|
|
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
|
//
|
|
// The returned value is the number of bytes for `ndarray.data`.
|
|
let ndarray_nbytes = ctx
|
|
.build_call_or_invoke(
|
|
rpc_recv,
|
|
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
|
"rpc.size.next",
|
|
)
|
|
.map(BasicValueEnum::into_int_value)
|
|
.unwrap();
|
|
|
|
// debug_assert(ndarray_nbytes > 0)
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
ctx.make_assert(
|
|
generator,
|
|
ctx.builder
|
|
.build_int_compare(
|
|
IntPredicate::UGT,
|
|
ndarray_nbytes,
|
|
ndarray_nbytes.get_type().const_zero(),
|
|
"",
|
|
)
|
|
.unwrap(),
|
|
"0:AssertionError",
|
|
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
|
[None, None, None],
|
|
ctx.current_loc,
|
|
);
|
|
}
|
|
|
|
// Copy shape from the buffer to `ndarray.shape`.
|
|
let pbuffer_dims =
|
|
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
|
|
|
call_memcpy_generic(
|
|
ctx,
|
|
ndarray.shape().base_ptr(ctx, generator),
|
|
pbuffer_dims,
|
|
sizeof_dims,
|
|
llvm_i1.const_zero(),
|
|
);
|
|
// Restore stack from before allocation of buffer
|
|
call_stackrestore(ctx, stackptr);
|
|
|
|
// Allocate `ndarray.data`.
|
|
// `ndarray.shape` must be initialized beforehand in this implementation
|
|
// (for ndarray.create_data() to know how many elements to allocate)
|
|
let num_elements =
|
|
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
|
|
|
|
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
let sizeof_data =
|
|
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
|
|
|
ctx.make_assert(
|
|
generator,
|
|
ctx.builder.build_int_compare(IntPredicate::UGE,
|
|
sizeof_data,
|
|
ndarray_nbytes,
|
|
"",
|
|
).unwrap(),
|
|
"0:AssertionError",
|
|
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
|
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
|
ctx.current_loc,
|
|
);
|
|
}
|
|
|
|
ndarray.create_data(generator, ctx, llvm_elem_ty, num_elements);
|
|
|
|
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
|
let ndarray_data_i8 =
|
|
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
|
|
|
// NOTE: Currently on `prehead_bb`
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
|
|
ctx.builder.position_at_end(head_bb);
|
|
|
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
|
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
|
|
|
let alloc_size = ctx
|
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
|
.map(BasicValueEnum::into_int_value)
|
|
.unwrap();
|
|
|
|
let is_done = ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
|
.unwrap();
|
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(alloc_bb);
|
|
// Align the allocation to sizeof(T)
|
|
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
|
let alloc_ptr = ctx
|
|
.builder
|
|
.build_array_alloca(
|
|
llvm_elem_ty,
|
|
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
|
"rpc.alloc",
|
|
)
|
|
.unwrap();
|
|
let alloc_ptr =
|
|
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(tail_bb);
|
|
ndarray.as_base_value().into()
|
|
}
|
|
|
|
_ => {
|
|
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
|
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
ctx.builder.position_at_end(head_bb);
|
|
|
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
|
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
|
let alloc_size = ctx
|
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
|
.unwrap()
|
|
.into_int_value();
|
|
let is_done = ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
|
.unwrap();
|
|
|
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
|
ctx.builder.position_at_end(alloc_bb);
|
|
|
|
let alloc_ptr =
|
|
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
|
let alloc_ptr =
|
|
ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
|
|
|
ctx.builder.position_at_end(tail_bb);
|
|
ctx.builder.build_load(slot, "rpc.result").unwrap()
|
|
}
|
|
};
|
|
|
|
Some(result)
|
|
}
|
|
|
|
fn rpc_codegen_callback_fn<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
|
generator: &mut dyn CodeGenerator,
|
|
is_async: bool,
|
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
|
let int8 = ctx.ctx.i8_type();
|
|
let int32 = ctx.ctx.i32_type();
|
|
let size_type = generator.get_size_type(ctx.ctx);
|
|
let ptr_type = int8.ptr_type(AddressSpace::default());
|
|
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
|
|
|
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
|
// -- setup rpc tags
|
|
let mut tag = Vec::new();
|
|
if obj.is_some() {
|
|
tag.push(b'O');
|
|
}
|
|
for arg in &fun.0.args {
|
|
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
|
}
|
|
tag.push(b':');
|
|
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
|
|
|
let mut hasher = DefaultHasher::new();
|
|
tag.hash(&mut hasher);
|
|
let hash = format!("{}", hasher.finish());
|
|
|
|
let tag_ptr = ctx
|
|
.module
|
|
.get_global(hash.as_str())
|
|
.unwrap_or_else(|| {
|
|
let tag_arr_ptr = ctx.module.add_global(
|
|
int8.array_type(tag.len() as u32),
|
|
None,
|
|
format!("tagptr{}", fun.1 .0).as_str(),
|
|
);
|
|
tag_arr_ptr.set_initializer(&int8.const_array(
|
|
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
|
));
|
|
tag_arr_ptr.set_linkage(Linkage::Private);
|
|
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
|
tag_ptr.set_linkage(Linkage::Private);
|
|
tag_ptr.set_initializer(&ctx.ctx.const_struct(
|
|
&[
|
|
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
|
|
size_type.const_int(tag.len() as u64, false).into(),
|
|
],
|
|
false,
|
|
));
|
|
tag_ptr
|
|
})
|
|
.as_pointer_value();
|
|
|
|
let arg_length = args.len() + usize::from(obj.is_some());
|
|
|
|
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
|
let args_ptr = ctx
|
|
.builder
|
|
.build_array_alloca(
|
|
ptr_type,
|
|
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
|
"argptr",
|
|
)
|
|
.unwrap();
|
|
|
|
// -- rpc args handling
|
|
let mut keys = fun.0.args.clone();
|
|
let mut mapping = HashMap::new();
|
|
for (key, value) in args {
|
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
|
}
|
|
// default value handling
|
|
for k in keys {
|
|
mapping
|
|
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
|
}
|
|
// reorder the parameters
|
|
let mut real_params = fun
|
|
.0
|
|
.args
|
|
.iter()
|
|
.map(|arg| {
|
|
mapping
|
|
.remove(&arg.name)
|
|
.unwrap()
|
|
.to_basic_value_enum(ctx, generator, arg.ty)
|
|
.map(|llvm_val| (llvm_val, arg.ty))
|
|
})
|
|
.collect::<Result<Vec<(_, _)>, _>>()?;
|
|
if let Some(obj) = obj {
|
|
if let ValueEnum::Static(obj_val) = obj.1 {
|
|
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
|
} else {
|
|
// should be an error here...
|
|
panic!("only host object is allowed");
|
|
}
|
|
}
|
|
|
|
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
|
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
|
let arg_ptr = unsafe {
|
|
ctx.builder.build_gep(
|
|
args_ptr,
|
|
&[int32.const_int(i as u64, false)],
|
|
&format!("rpc.arg{i}"),
|
|
)
|
|
}
|
|
.unwrap();
|
|
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
|
}
|
|
|
|
// call
|
|
if is_async {
|
|
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
|
ctx.module.add_function(
|
|
"rpc_send_async",
|
|
ctx.ctx.void_type().fn_type(
|
|
&[
|
|
int32.into(),
|
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
],
|
|
false,
|
|
),
|
|
None,
|
|
)
|
|
});
|
|
ctx.builder
|
|
.build_call(
|
|
rpc_send_async,
|
|
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
|
"rpc.send",
|
|
)
|
|
.unwrap();
|
|
} else {
|
|
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
|
ctx.module.add_function(
|
|
"rpc_send",
|
|
ctx.ctx.void_type().fn_type(
|
|
&[
|
|
int32.into(),
|
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
|
],
|
|
false,
|
|
),
|
|
None,
|
|
)
|
|
});
|
|
ctx.builder
|
|
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
|
.unwrap();
|
|
}
|
|
|
|
// reclaim stack space used by arguments
|
|
call_stackrestore(ctx, stackptr);
|
|
|
|
if is_async {
|
|
// async RPCs do not return any values
|
|
Ok(None)
|
|
} else {
|
|
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
|
|
|
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
|
// An RPC returning an NDArray would not touch here.
|
|
call_stackrestore(ctx, stackptr);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
pub fn attributes_writeback<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
inner_resolver: &InnerResolver,
|
|
host_attributes: &PyObject,
|
|
return_obj: Option<(Type, ValueEnum<'ctx>)>,
|
|
) -> Result<(), String> {
|
|
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
|
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
|
let top_levels = ctx.top_level.definitions.read();
|
|
let globals = inner_resolver.global_value_ids.read();
|
|
let int32 = ctx.ctx.i32_type();
|
|
let zero = int32.const_zero();
|
|
let mut values = Vec::new();
|
|
let mut scratch_buffer = Vec::new();
|
|
|
|
if let Some((ty, obj)) = return_obj {
|
|
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
|
|
}
|
|
|
|
for val in (*globals).values() {
|
|
let val = val.as_ref(py);
|
|
let ty = inner_resolver.get_obj_type(
|
|
py,
|
|
val,
|
|
&mut ctx.unifier,
|
|
&top_levels,
|
|
&ctx.primitives,
|
|
)?;
|
|
if let Err(ty) = ty {
|
|
return Ok(Err(ty));
|
|
}
|
|
let ty = ty.unwrap();
|
|
match &*ctx.unifier.get_ty(ty) {
|
|
TypeEnum::TObj { fields, obj_id, .. }
|
|
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
|
{
|
|
// we only care about primitive attributes
|
|
// for non-primitive attributes, they should be in another global
|
|
let mut attributes = Vec::new();
|
|
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
|
|
for (name, (field_ty, is_mutable)) in fields {
|
|
if !is_mutable {
|
|
continue;
|
|
}
|
|
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
|
attributes.push(name.to_string());
|
|
let (index, _) = ctx.get_attr_index(ty, *name);
|
|
values.push((
|
|
*field_ty,
|
|
ctx.build_gep_and_load(
|
|
obj.into_pointer_value(),
|
|
&[zero, int32.const_int(index as u64, false)],
|
|
None,
|
|
),
|
|
));
|
|
}
|
|
}
|
|
if !attributes.is_empty() {
|
|
let pydict = PyDict::new(py);
|
|
pydict.set_item("obj", val)?;
|
|
pydict.set_item("fields", attributes)?;
|
|
host_attributes.append(pydict)?;
|
|
}
|
|
}
|
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let elem_ty = iter_type_vars(params).next().unwrap().ty;
|
|
|
|
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
|
|
let pydict = PyDict::new(py);
|
|
pydict.set_item("obj", val)?;
|
|
host_attributes.append(pydict)?;
|
|
values.push((
|
|
ty,
|
|
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
|
|
));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
let fun = FunSignature {
|
|
args: values
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, (ty, _))| FuncArg {
|
|
name: i.to_string().into(),
|
|
ty: *ty,
|
|
default_value: None,
|
|
is_vararg: false,
|
|
})
|
|
.collect(),
|
|
ret: ctx.primitives.none,
|
|
vars: VarMap::default(),
|
|
};
|
|
let args: Vec<_> =
|
|
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
|
if let Err(e) =
|
|
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, true)
|
|
{
|
|
return Ok(Err(e));
|
|
}
|
|
Ok(Ok(()))
|
|
})
|
|
.unwrap()?;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
|
|
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
|
rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
|
|
})))
|
|
}
|
|
|
|
/// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with
|
|
/// [`llvm_usize`] as its native word size.
|
|
///
|
|
/// Note that, similar to format constants in `<inttypes.h>`, these constants need to be prepended
|
|
/// with `%`.
|
|
#[must_use]
|
|
fn get_fprintf_format_constant<'ctx>(
|
|
llvm_usize: IntType<'ctx>,
|
|
llvm_int_t: IntType<'ctx>,
|
|
is_unsigned: bool,
|
|
) -> String {
|
|
debug_assert!(matches!(llvm_usize.get_bit_width(), 8 | 16 | 32 | 64));
|
|
|
|
let conv_spec = if is_unsigned { 'u' } else { 'd' };
|
|
|
|
// https://en.cppreference.com/w/c/language/arithmetic_types
|
|
// Note that NAC3 does **not** support LP32 and LLP64 configurations
|
|
match llvm_int_t.get_bit_width() {
|
|
8 => format!("hh{conv_spec}"),
|
|
16 => format!("h{conv_spec}"),
|
|
32 => conv_spec.to_string(),
|
|
64 => format!("{}{conv_spec}", if llvm_usize.get_bit_width() == 64 { "l" } else { "ll" }),
|
|
_ => todo!(
|
|
"Not yet implemented for i{} on {}-bit platform",
|
|
llvm_int_t.get_bit_width(),
|
|
llvm_usize.get_bit_width()
|
|
),
|
|
}
|
|
}
|
|
|
|
/// Prints one or more `values` to `core_log` or `rtio_log`.
|
|
///
|
|
/// * `separator` - The separator between multiple values.
|
|
/// * `suffix` - String to terminate the printed string, if any.
|
|
/// * `as_repr` - Whether the `repr()` output of values instead of `str()`.
|
|
/// * `as_rtio` - Whether to print to `rtio_log` instead of `core_log`.
|
|
fn polymorphic_print<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
values: &[(Type, ValueEnum<'ctx>)],
|
|
separator: &str,
|
|
suffix: Option<&str>,
|
|
as_repr: bool,
|
|
as_rtio: bool,
|
|
) -> Result<(), String> {
|
|
let printf = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
fmt: String,
|
|
args: Vec<BasicValueEnum<'ctx>>| {
|
|
debug_assert!(!fmt.is_empty());
|
|
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
|
|
|
|
let fn_name = if as_rtio { "rtio_log" } else { "core_log" };
|
|
let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
|
|
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
|
let fn_t = if as_rtio {
|
|
let llvm_void = ctx.ctx.void_type();
|
|
llvm_void.fn_type(&[llvm_pi8.into()], true)
|
|
} else {
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
llvm_i32.fn_type(&[llvm_pi8.into()], true)
|
|
};
|
|
ctx.module.add_function(fn_name, fn_t, None)
|
|
});
|
|
|
|
let fmt = ctx.gen_string(generator, fmt);
|
|
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
|
|
ctx.builder
|
|
.build_call(
|
|
print_fn,
|
|
&once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(),
|
|
"",
|
|
)
|
|
.unwrap();
|
|
};
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
let llvm_i64 = ctx.ctx.i64_type();
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
let suffix = suffix.unwrap_or_default();
|
|
|
|
let mut fmt = String::new();
|
|
let mut args = Vec::new();
|
|
|
|
let flush = |ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
fmt: &mut String,
|
|
args: &mut Vec<BasicValueEnum<'ctx>>| {
|
|
if !fmt.is_empty() {
|
|
fmt.push('\0');
|
|
printf(ctx, generator, mem::take(fmt), mem::take(args));
|
|
}
|
|
};
|
|
|
|
for (ty, value) in values {
|
|
let ty = *ty;
|
|
let value = value.clone().to_basic_value_enum(ctx, generator, ty).unwrap();
|
|
|
|
if !fmt.is_empty() {
|
|
fmt.push_str(separator);
|
|
}
|
|
|
|
match &*ctx.unifier.get_ty_immutable(ty) {
|
|
TypeEnum::TTuple { ty: tys, is_vararg_ctx: false } => {
|
|
let pvalue = {
|
|
let pvalue = generator.gen_var_alloc(ctx, value.get_type(), None).unwrap();
|
|
ctx.builder.build_store(pvalue, value).unwrap();
|
|
pvalue
|
|
};
|
|
|
|
fmt.push('(');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let tuple_vals = tys
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, ty)| {
|
|
(*ty, {
|
|
let pfield =
|
|
ctx.builder.build_struct_gep(pvalue, i as u32, "").unwrap();
|
|
|
|
ValueEnum::from(ctx.builder.build_load(pfield, "").unwrap())
|
|
})
|
|
})
|
|
.collect_vec();
|
|
|
|
polymorphic_print(ctx, generator, &tuple_vals, ", ", None, true, as_rtio)?;
|
|
|
|
if tuple_vals.len() == 1 {
|
|
fmt.push_str(",)");
|
|
} else {
|
|
fmt.push(')');
|
|
}
|
|
}
|
|
|
|
TypeEnum::TFunc { .. } => todo!(),
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::None.id() => {
|
|
fmt.push_str("None");
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Bool.id() => {
|
|
fmt.push_str("%.*s");
|
|
|
|
let true_str = ctx.gen_string(generator, "True");
|
|
let true_data =
|
|
unsafe { true_str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let true_len = unsafe { true_str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
let false_str = ctx.gen_string(generator, "False");
|
|
let false_data =
|
|
unsafe { false_str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let false_len =
|
|
unsafe { false_str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
|
|
let bool_val = generator.bool_to_i1(ctx, value.into_int_value());
|
|
|
|
args.extend([
|
|
ctx.builder.build_select(bool_val, true_len, false_len, "").unwrap(),
|
|
ctx.builder.build_select(bool_val, true_data, false_data, "").unwrap(),
|
|
]);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. }
|
|
if *obj_id == PrimDef::Int32.id()
|
|
|| *obj_id == PrimDef::Int64.id()
|
|
|| *obj_id == PrimDef::UInt32.id()
|
|
|| *obj_id == PrimDef::UInt64.id() =>
|
|
{
|
|
let is_unsigned =
|
|
*obj_id == PrimDef::UInt32.id() || *obj_id == PrimDef::UInt64.id();
|
|
|
|
let llvm_int_t = value.get_type().into_int_type();
|
|
debug_assert!(matches!(llvm_usize.get_bit_width(), 32 | 64));
|
|
debug_assert!(matches!(llvm_int_t.get_bit_width(), 32 | 64));
|
|
|
|
let fmt_spec = format!(
|
|
"%{}",
|
|
get_fprintf_format_constant(llvm_usize, llvm_int_t, is_unsigned)
|
|
);
|
|
|
|
fmt.push_str(fmt_spec.as_str());
|
|
args.push(value);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Float.id() => {
|
|
fmt.push_str("%g");
|
|
args.push(value);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Str.id() => {
|
|
if as_repr {
|
|
fmt.push_str("\"%.*s\"");
|
|
} else {
|
|
fmt.push_str("%.*s");
|
|
}
|
|
|
|
let str = value.into_struct_value();
|
|
let str_data = unsafe { str.get_field_at_index_unchecked(0) }.into_pointer_value();
|
|
let str_len = unsafe { str.get_field_at_index_unchecked(1) }.into_int_value();
|
|
|
|
args.extend(&[str_len.into(), str_data.into()]);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
let elem_ty = *params.iter().next().unwrap().1;
|
|
|
|
fmt.push('[');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val =
|
|
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
|
|
let len = val.load_size(ctx, None);
|
|
let last =
|
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
|
|
|
gen_for_callback_incrementing(
|
|
generator,
|
|
ctx,
|
|
None,
|
|
llvm_usize.const_zero(),
|
|
(len, false),
|
|
|generator, ctx, _, i| {
|
|
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(elem_ty, elem.into())],
|
|
"",
|
|
None,
|
|
true,
|
|
as_rtio,
|
|
)?;
|
|
|
|
gen_if_callback(
|
|
generator,
|
|
ctx,
|
|
|_, ctx| {
|
|
Ok(ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::ULT, i, last, "")
|
|
.unwrap())
|
|
},
|
|
|generator, ctx| {
|
|
printf(ctx, generator, ", \0".into(), Vec::default());
|
|
|
|
Ok(())
|
|
},
|
|
|_, _| Ok(()),
|
|
)?;
|
|
|
|
Ok(())
|
|
},
|
|
llvm_usize.const_int(1, false),
|
|
)?;
|
|
|
|
fmt.push(']');
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
|
|
fmt.push_str("array([");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val = NDArrayType::from_unifier_type(generator, ctx, ty)
|
|
.map_value(value.into_pointer_value(), None);
|
|
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
|
|
let last =
|
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
|
|
|
gen_for_callback_incrementing(
|
|
generator,
|
|
ctx,
|
|
None,
|
|
llvm_usize.const_zero(),
|
|
(len, false),
|
|
|generator, ctx, _, i| {
|
|
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(elem_ty, elem.into())],
|
|
"",
|
|
None,
|
|
true,
|
|
as_rtio,
|
|
)?;
|
|
|
|
gen_if_callback(
|
|
generator,
|
|
ctx,
|
|
|_, ctx| {
|
|
Ok(ctx
|
|
.builder
|
|
.build_int_compare(IntPredicate::ULT, i, last, "")
|
|
.unwrap())
|
|
},
|
|
|generator, ctx| {
|
|
printf(ctx, generator, ", \0".into(), Vec::default());
|
|
|
|
Ok(())
|
|
},
|
|
|_, _| Ok(()),
|
|
)?;
|
|
|
|
Ok(())
|
|
},
|
|
llvm_usize.const_int(1, false),
|
|
)?;
|
|
|
|
fmt.push_str(")]");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Range.id() => {
|
|
fmt.push_str("range(");
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
let val = RangeValue::from_pointer_value(value.into_pointer_value(), None);
|
|
|
|
let (start, stop, step) = destructure_range(ctx, val);
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[
|
|
(ctx.primitives.int32, start.into()),
|
|
(ctx.primitives.int32, stop.into()),
|
|
(ctx.primitives.int32, step.into()),
|
|
],
|
|
", ",
|
|
None,
|
|
false,
|
|
as_rtio,
|
|
)?;
|
|
|
|
fmt.push(')');
|
|
}
|
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Exception.id() => {
|
|
let fmt_str = format!(
|
|
"%{}(%{}, %{1:}, %{1:})",
|
|
get_fprintf_format_constant(llvm_usize, llvm_i32, false),
|
|
get_fprintf_format_constant(llvm_usize, llvm_i64, false),
|
|
);
|
|
|
|
let exn = value.into_pointer_value();
|
|
let name = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param0 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(6, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param1 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(7, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
let param2 = ctx
|
|
.build_in_bounds_gep_and_load(
|
|
exn,
|
|
&[llvm_i32.const_zero(), llvm_i32.const_int(8, false)],
|
|
None,
|
|
)
|
|
.into_int_value();
|
|
|
|
fmt.push_str(fmt_str.as_str());
|
|
args.extend_from_slice(&[name.into(), param0.into(), param1.into(), param2.into()]);
|
|
}
|
|
|
|
_ => unreachable!(
|
|
"Unsupported object type for polymorphic_print: {}",
|
|
ctx.unifier.stringify(ty)
|
|
),
|
|
}
|
|
}
|
|
|
|
fmt.push_str(suffix);
|
|
flush(ctx, generator, &mut fmt, &mut args);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Invokes the `core_log` intrinsic function.
|
|
pub fn call_core_log_impl<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
arg: (Type, BasicValueEnum<'ctx>),
|
|
) -> Result<(), String> {
|
|
let (arg_ty, arg_val) = arg;
|
|
|
|
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\n"), false, false)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Invokes the `rtio_log` intrinsic function.
|
|
pub fn call_rtio_log_impl<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
generator: &mut dyn CodeGenerator,
|
|
channel: StructValue<'ctx>,
|
|
arg: (Type, BasicValueEnum<'ctx>),
|
|
) -> Result<(), String> {
|
|
let (arg_ty, arg_val) = arg;
|
|
|
|
polymorphic_print(
|
|
ctx,
|
|
generator,
|
|
&[(ctx.primitives.str, channel.into())],
|
|
" ",
|
|
Some("\x1E"),
|
|
false,
|
|
true,
|
|
)?;
|
|
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\x1D"), false, true)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Generates a call to `core_log`.
|
|
pub fn gen_core_log<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
generator: &mut dyn CodeGenerator,
|
|
) -> Result<(), String> {
|
|
assert!(obj.is_none());
|
|
assert_eq!(args.len(), 1);
|
|
|
|
let value_ty = fun.0.args[0].ty;
|
|
let value_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
|
|
|
|
call_core_log_impl(ctx, generator, (value_ty, value_arg))
|
|
}
|
|
|
|
/// Generates a call to `rtio_log`.
|
|
pub fn gen_rtio_log<'ctx>(
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
fun: (&FunSignature, DefinitionId),
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
generator: &mut dyn CodeGenerator,
|
|
) -> Result<(), String> {
|
|
assert!(obj.is_none());
|
|
assert_eq!(args.len(), 2);
|
|
|
|
let channel_ty = fun.0.args[0].ty;
|
|
assert!(ctx.unifier.unioned(channel_ty, ctx.primitives.str));
|
|
let channel_arg =
|
|
args[0].1.clone().to_basic_value_enum(ctx, generator, channel_ty)?.into_struct_value();
|
|
let value_ty = fun.0.args[1].ty;
|
|
let value_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
|
|
|
|
call_rtio_log_impl(ctx, generator, channel_arg, (value_ty, value_arg))
|
|
}
|