[meta] Refactor Result to anyhow::Result

This commit is contained in:
2026-02-16 11:22:03 +08:00
parent bc3c7f0e21
commit dc670f3df3
56 changed files with 3395 additions and 3410 deletions

2
Cargo.lock generated
View File

@@ -633,6 +633,7 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
name = "nac3artiq"
version = "0.1.0"
dependencies = [
"anyhow",
"indexmap",
"itertools",
"nac3binutils",
@@ -710,6 +711,7 @@ dependencies = [
name = "nac3standalone"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"nac3core",
"parking_lot",

View File

@@ -9,6 +9,7 @@ name = "nac3artiq"
crate-type = ["cdylib"]
[dependencies]
anyhow = "1.0"
indexmap = "2.12"
itertools = "0.14"
pyo3 = { version = "0.28", features = ["extension-module"] }

View File

@@ -7,6 +7,7 @@ from numpy import int32
# Self-referential ProtoRev types
@compile
class ProtoRev8:
cpld: KernelInvariant[CPLD[ProtoRev8]]

View File

@@ -5,7 +5,7 @@ use std::{
sync::Arc,
};
use itertools::Itertools as _;
use anyhow::{anyhow, bail};
use nac3core::{
codegen::{
CodeGenContext, CodeGenerator, VarValue, basic_type_all, bool_to_i1,
@@ -111,10 +111,10 @@ impl<'a> ArtiqCodeGenerator<'a> {
///
/// Direct-`parallel` block context refers to when the generator is generating statements whose
/// closest parent `with` statement is a `with parallel` block.
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> anyhow::Result<()> {
if let Some(start) = self.start.clone() {
let start_val = self.gen_expr(ctx, &start)?.to_basic_value_enum(ctx)?;
self.timeline.emit_at_mu(ctx, start_val);
self.timeline.emit_at_mu(ctx, start_val)?;
}
Ok(())
@@ -137,12 +137,12 @@ impl<'a> ArtiqCodeGenerator<'a> {
ctx: &mut CodeGenContext<'_, '_>,
end: Option<Expr<Option<Type>>>,
store_name: Option<&str>,
) -> Result<(), String> {
) -> anyhow::Result<()> {
if let Some(end) = end {
let old_end = self.gen_expr(ctx, &end)?.to_basic_value_enum(ctx)?;
let now = self.timeline.emit_now_mu(ctx);
let now = self.timeline.emit_now_mu(ctx)?;
let max =
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"))?;
let end_store = self
.gen_store_target(
ctx,
@@ -150,7 +150,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
store_name.map(|name| format!("{name}.addr")).as_deref(),
)?
.unwrap();
typed_store(ctx.builder, end_store, max);
typed_store(ctx.builder, end_store, max)?;
}
Ok(())
@@ -166,7 +166,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -196,7 +196,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>> {
let result = gen_call(self, ctx, obj, fun, params)?;
// Deep parallel emits timeline end-update/timeline-reset after each function call
@@ -212,7 +212,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
) -> anyhow::Result<()> {
let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
if items.len() == 1 && items[0].optional_vars.is_none() {
@@ -237,7 +237,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
ctx.var_assignment.get(id)
{
static_value.clone()
} else if let Some(ValueEnum::Static(val)) = resolver.get_symbol_value(*id, ctx) {
} else if let Some(ValueEnum::Static(val)) = resolver.get_symbol_value(*id, ctx)? {
Some(val)
} else {
None
@@ -253,7 +253,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)?.to_basic_value_enum(ctx)?
} else {
self.timeline.emit_now_mu(ctx)
self.timeline.emit_now_mu(ctx)?
};
// Emulate variable allocation, as we need to use the CodeGenContext
@@ -275,8 +275,8 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
let start = self
.gen_store_target(ctx, &start_expr, Some("start.addr"))?
.unwrap();
typed_store(ctx.builder, start, now);
Ok(Some(start_expr)) as Result<_, String>
typed_store(ctx.builder, start, now)?;
anyhow::Ok(Some(start_expr))
},
|v| Ok(Some(v)),
)?;
@@ -288,7 +288,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
custom: Some(ctx.primitives.int64),
};
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
typed_store(ctx.builder, end, now);
typed_store(ctx.builder, end, now)?;
self.end = Some(end_expr);
self.name_counter += 1;
self.parallel_mode = if python_id == self.special_ids.parallel {
@@ -323,7 +323,7 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
// inside a sequential block
if old_start.is_none() {
self.timeline.emit_at_mu(ctx, end_val);
self.timeline.emit_at_mu(ctx, end_val)?;
}
// inside a parallel block, should update the outer max now_mu
@@ -369,7 +369,7 @@ fn gen_rpc_tag(
ctx: &mut CodeGenContext<'_, '_>,
ty: Type,
buffer: &mut Vec<u8>,
) -> Result<(), String> {
) -> anyhow::Result<()> {
let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives;
if ctx.unifier.unioned(ty, int32) {
@@ -406,15 +406,15 @@ fn gen_rpc_tag(
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
{
if values.len() != 1 {
return Err(format!(
bail!(
"NDArray types with multiple literal bounds for ndims is not supported: {}",
ctx.unifier.stringify(ty)
));
);
}
let value = values[0].clone();
u64::try_from(value.clone())
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
.map_err(|()| anyhow!("Expected u64 for ndarray.ndims, got {value}"))?
} else {
unreachable!()
};
@@ -427,7 +427,7 @@ fn gen_rpc_tag(
buffer.push((ndarray_ndims & 0xFF) as u8);
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
}
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
_ => bail!("Unsupported type: {:?}", ctx.unifier.stringify(ty)),
}
}
Ok(())
@@ -439,7 +439,7 @@ fn gen_rpc_tag(
fn format_rpc_arg<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
(arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
) -> PointerValue<'ctx> {
) -> anyhow::Result<PointerValue<'ctx>> {
let llvm_i8 = ctx.i8;
let llvm_pi8 = ctx.ptr;
@@ -457,7 +457,7 @@ fn format_rpc_arg<'ctx>(
// `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
// the reader.
// Turning it into a ContiguousNDArray to get a `data` that is contiguous.
let carray = ndarray.make_contiguous_ndarray(ctx);
let carray = ndarray.make_contiguous_ndarray(ctx)?;
let sizeof_usize = ctx.sizeof(ctx.size_t);
let sizeof_pdata = ctx.sizeof(ctx.ptr);
@@ -465,45 +465,42 @@ fn format_rpc_arg<'ctx>(
let sizeof_buf = sizeof_buf_shape + sizeof_pdata;
// buf = { data: void*, shape: [size_t; ndims]; }
let buf = gen_array_var(ctx, llvm_i8, sizeof_buf, Some("rpc.arg"));
let buf = gen_array_var(ctx, llvm_i8, sizeof_buf, Some("rpc.arg"))?;
let buf_data = buf.value.0;
let sizeof_pdata_ = ctx.size_t.const_int(sizeof_pdata, false);
let buf_shape = buf.ptr_offset_unchecked(ctx, &sizeof_pdata_, None);
let buf_shape = buf.ptr_offset_unchecked(ctx, &sizeof_pdata_, None)?;
// Write to `buf->data`
let carray_data = carray.load(ctx, field!(data));
let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "").unwrap();
call_memcpy(ctx, buf_data, carray_data, sizeof_pdata_);
let carray_data = carray.load(ctx, field!(data))?;
let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "")?;
call_memcpy(ctx, buf_data, carray_data, sizeof_pdata_)?;
// Write to `buf->shape`
let carray_shape = ndarray.shape(ctx).value.0;
let carray_shape = ndarray.shape(ctx)?.value.0;
let sizeof_buf_shape_ = ctx.size_t.const_int(sizeof_buf_shape, false);
call_memcpy(ctx, buf_shape, carray_shape, sizeof_buf_shape_);
call_memcpy(ctx, buf_shape, carray_shape, sizeof_buf_shape_)?;
buf.value.0
}
_ => {
let arg_slot = gen_var(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")));
typed_store(ctx.builder, arg_slot, arg);
let arg_slot = gen_var(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))?;
typed_store(ctx.builder, arg_slot, arg)?;
ctx.builder
.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
ctx.builder.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")?.into_pointer_value()
}
};
debug_assert_eq!(arg_slot.get_type(), llvm_pi8);
arg_slot
Ok(arg_slot)
}
/// Formats an RPC return value to conform to the expected format required by NAC3.
fn format_rpc_ret<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
ret_ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>> {
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
@@ -521,8 +518,8 @@ fn format_rpc_ret<'ctx>(
ctx.declare_external("rpc_recv", Some(llvm_i32.into()), &[llvm_pi8.into()], false, &[]);
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
let _ = ctx.build_call_or_invoke(&rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
return None;
let _ = ctx.build_call_or_invoke(&rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv")?;
return Ok(None);
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
@@ -548,17 +545,13 @@ fn format_rpc_ret<'ctx>(
let llvm_val_t = val.get_type();
let max_rem = ctx
.builder
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
.unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
let max_rem =
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")?;
anyhow::Ok(ctx.builder.build_and(
ctx.builder.build_int_add(val, max_rem, "")?,
ctx.builder.build_not(max_rem, "")?,
"",
)?)
};
// Allocate the resulting ndarray
@@ -566,7 +559,7 @@ fn format_rpc_ret<'ctx>(
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let dtype_llvm = ctx.get_llvm_type(dtype);
let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims).construct(ctx, None);
let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims).construct(ctx, None)?;
// NOTE: Current content of `ndarray`:
// - * `data` - **NOT YET** allocated.
@@ -575,7 +568,7 @@ fn format_rpc_ret<'ctx>(
// - * `shape` - allocated; has uninitialized values.
// - * `strides` - allocated; has uninitialized values.
let stackptr = call_stacksave(ctx, None);
let stackptr = call_stacksave(ctx, None)?;
let itemsize = ctx.sizeof(ndarray.ty.dtype);
let sizeof_ptr = ctx.sizeof(ctx.ptr);
@@ -586,7 +579,7 @@ fn format_rpc_ret<'ctx>(
// Force an aligned allocation.
let chunks = unaligned_buffer_size.div_ceil(8);
let aligned_alloc_ty = ctx.ctx.struct_type(&[ctx.i8.array_type(8).into()], false);
let ptr = gen_array_var(ctx, aligned_alloc_ty, chunks, Some("rpc.buffer")).value.0;
let ptr = gen_array_var(ctx, aligned_alloc_ty, chunks, Some("rpc.buffer"))?.value.0;
let buffer_bytes = ctx.size_t.const_int(chunks * 8, false);
let buffer = ArraySliceValue::new(ctx.i8.into(), ptr, buffer_bytes, None);
@@ -598,16 +591,14 @@ fn format_rpc_ret<'ctx>(
&rpc_recv,
&[buffer.value.0.into()], // Reads [usize; ndims]
"rpc.size.next",
)
)?
.map(BasicValueEnum::into_int_value)
.unwrap();
// debug_assert(ndarray_nbytes > 0)
if ctx.registry.codegen_options.debug {
let cmp = ctx
.builder
.build_int_compare(IntPredicate::UGT, ndarray_nbytes, num_0, "")
.unwrap();
let cmp =
ctx.builder.build_int_compare(IntPredicate::UGT, ndarray_nbytes, num_0, "")?;
ctx.make_assert(
cmp,
@@ -615,39 +606,36 @@ fn format_rpc_ret<'ctx>(
"Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None],
ctx.current_loc,
);
)?;
}
// Copy shape from the buffer to `ndarray.shape`.
// We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
let sizeof_ptr = ctx.size_t.const_int(sizeof_ptr, false);
let pbuffer_shape = buffer.ptr_offset_unchecked(ctx, &sizeof_ptr, None);
ndarray.shape(ctx).memcpy_from(ctx, pbuffer_shape);
let pbuffer_shape = buffer.ptr_offset_unchecked(ctx, &sizeof_ptr, None)?;
ndarray.shape(ctx)?.memcpy_from(ctx, pbuffer_shape)?;
// Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr);
call_stackrestore(ctx, stackptr)?;
// Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate)
ndarray.create_data(ctx); // NOTE: the strides of `ndarray` has also been set to contiguous in `create_data`.
ndarray.create_data(ctx)?; // NOTE: the strides of `ndarray` has also been set to contiguous in `create_data`.
let itemsize = ctx.size_t.const_int(itemsize, false);
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.codegen_options.debug {
let num_elements = ndarray.size(ctx);
let num_elements = ndarray.size(ctx)?;
let expected_ndarray_nbytes =
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
let cmp = ctx
.builder
.build_int_compare(
IntPredicate::UGE,
expected_ndarray_nbytes,
ndarray_nbytes,
"",
)
.unwrap();
ctx.builder.build_int_mul(num_elements, itemsize, "")?;
let cmp = ctx.builder.build_int_compare(
IntPredicate::UGE,
expected_ndarray_nbytes,
ndarray_nbytes,
"",
)?;
ctx.make_assert(
cmp,
@@ -655,73 +643,78 @@ fn format_rpc_ret<'ctx>(
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(expected_ndarray_nbytes), Some(ndarray_nbytes), None],
ctx.current_loc,
);
)?;
}
let ndarray_data = ndarray.load(ctx, field!(data));
let ndarray_data = ndarray.load(ctx, field!(data))?;
let entry_bb = ctx.builder.get_insert_block().unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.build_unconditional_branch(head_bb)?;
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr")?;
phi.add_incoming(&[(&ndarray_data, entry_bb)]);
let alloc_size = ctx
.build_call_or_invoke(&rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.build_call_or_invoke(&rpc_recv, &[phi.as_basic_value()], "rpc.size.next")?
.map(BasicValueEnum::into_int_value)
.unwrap();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
let is_done = ctx.builder.build_int_compare(
IntPredicate::EQ,
llvm_i32.const_zero(),
alloc_size,
"rpc.done",
)?;
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb)?;
ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, itemsize);
let size = ctx.builder.build_int_unsigned_div(alloc_size, itemsize, "").unwrap();
let alloc_ptr = gen_dyn_array_var(ctx, dtype_llvm, size, Some("rpc.alloc")).value.0;
let alloc_size = round_up(ctx, alloc_size, itemsize)?;
let size = ctx.builder.build_int_unsigned_div(alloc_size, itemsize, "")?;
let alloc_ptr = gen_dyn_array_var(ctx, dtype_llvm, size, Some("rpc.alloc"))?.value.0;
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.build_unconditional_branch(head_bb)?;
ctx.builder.position_at_end(tail_bb);
ndarray.value.into()
}
_ => {
let slot = gen_var(ctx, llvm_ret_ty, Some("rpc.ret.slot"));
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
let slot = gen_var(ctx, llvm_ret_ty, Some("rpc.ret.slot"))?;
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr")?;
ctx.builder.build_unconditional_branch(head_bb)?;
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr")?;
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(&rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.build_call_or_invoke(&rpc_recv, &[phi.as_basic_value()], "rpc.size.next")?
.map(BasicValueEnum::into_int_value)
.unwrap();
let is_done = ctx.builder.build_int_compare(
IntPredicate::EQ,
llvm_i32.const_zero(),
alloc_size,
"rpc.done",
)?;
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb)?;
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = gen_dyn_array_var(ctx, llvm_pi8, alloc_size, Some("rpc.alloc")).value.0;
let alloc_ptr =
gen_dyn_array_var(ctx, llvm_pi8, alloc_size, Some("rpc.alloc"))?.value.0;
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.build_unconditional_branch(head_bb)?;
ctx.builder.position_at_end(tail_bb);
ctx.builder.build_load(slot, "rpc.result").unwrap()
ctx.builder.build_load(slot, "rpc.result")?
}
};
Some(result)
Ok(Some(result))
}
fn rpc_codegen_callback_fn<'ctx>(
@@ -730,7 +723,7 @@ fn rpc_codegen_callback_fn<'ctx>(
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>> {
let int8 = ctx.i8;
let int32 = ctx.i32;
let size_type = ctx.size_t;
@@ -781,8 +774,8 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_length = args.len() as u64 + u64::from(obj.is_some());
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = gen_array_var(ctx, ctx.ptr, arg_length, Some("argptr"));
let stackptr = call_stacksave(ctx, Some("rpc.stack"))?;
let args_ptr = gen_array_var(ctx, ctx.ptr, arg_length, Some("argptr"))?;
// -- rpc args handling
let mut keys = fun.0.args.clone();
@@ -792,7 +785,7 @@ fn rpc_codegen_callback_fn<'ctx>(
}
// default value handling
for k in keys {
mapping.insert(k.name, ctx.gen_symbol_val(&k.default_value.unwrap(), k.ty).into());
mapping.insert(k.name, ctx.gen_symbol_val(&k.default_value.unwrap(), k.ty)?.into());
}
// reorder the parameters
let mut real_params = fun
@@ -809,7 +802,7 @@ fn rpc_codegen_callback_fn<'ctx>(
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx), obj.0));
real_params.insert(0, (obj_val.get_const_obj(ctx)?, obj.0));
} else {
// should be an error here...
panic!("only host object is allowed");
@@ -817,29 +810,29 @@ fn rpc_codegen_callback_fn<'ctx>(
}
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = format_rpc_arg(ctx, (*arg, *arg_ty, i));
let arg_slot = format_rpc_arg(ctx, (*arg, *arg_ty, i))?;
let name = format!("rpc.arg{i}");
let i = ctx.size_t.const_int(i as u64, false);
args_ptr.set_unchecked(ctx, &i, arg_slot, Some(&name));
args_ptr.set_unchecked(ctx, &i, arg_slot, Some(&name))?;
}
call_extern!(ctx: void "rpc.send" =
(if is_async { "rpc_send_async" } else { "rpc_send" })(service_id, tag_ptr, args_ptr.value.0));
(if is_async { "rpc_send_async" } else { "rpc_send" })(service_id, tag_ptr, args_ptr.value.0))?;
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr);
call_stackrestore(ctx, stackptr)?;
if is_async {
// async RPCs do not return any values
Ok(None)
} else {
let result = format_rpc_ret(ctx, fun.0.ret);
let result = format_rpc_ret(ctx, fun.0.ret)?;
// Here we call `basic_type_all` to ensure that the return type is not, nor contains, a
// pointer type which may require further allocation, in which case the stack should not
// be restored, as this will lead to undefined behavior.
if result.is_some_and(|res| basic_type_all(&res.get_type(), &|t| !t.is_pointer_type())) {
call_stackrestore(ctx, stackptr);
call_stackrestore(ctx, stackptr)?;
}
Ok(result)
@@ -851,8 +844,8 @@ pub fn attributes_writeback<'ctx>(
inner_resolver: &InnerResolver,
host_attributes: &Py<PyAny>,
return_obj: Option<(Type, ValueEnum<'ctx>)>,
) -> Result<(), String> {
Python::attach(|py| -> PyResult<Result<(), String>> {
) -> anyhow::Result<()> {
Python::attach(|py| -> PyResult<anyhow::Result<()>> {
let host_attributes = host_attributes.cast_bound::<PyList>(py)?;
let int32 = ctx.i32;
let zero = int32.const_zero();
@@ -860,7 +853,13 @@ pub fn attributes_writeback<'ctx>(
let mut scratch_buffer = Vec::new();
if let Some((ty, obj)) = return_obj {
values.push((ty, obj.to_basic_value_enum(ctx, ty).unwrap()));
values.push((
ty,
match obj.to_basic_value_enum(ctx, ty) {
Ok(v) => v,
Err(e) => return Ok(Err(e)),
},
));
}
for val in (*inner_resolver.global_value_ids.read()).values() {
@@ -872,10 +871,10 @@ pub fn attributes_writeback<'ctx>(
&ctx.top_level.definitions.read(),
&ctx.primitives,
)?;
if let Err(ty) = ty {
return Ok(Err(ty));
}
let ty = ty.unwrap();
let ty = match ty {
Ok(ty) => ty,
Err(e) => return Ok(Err(anyhow!("{e}"))),
};
match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let elem_ty = iter_type_vars(params).next().unwrap().ty;
@@ -884,7 +883,11 @@ pub fn attributes_writeback<'ctx>(
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, ty)?.unwrap()));
let value = match inner_resolver.get_obj_value(py, val, ctx, ty)? {
Ok(v) => v.unwrap(),
Err(err) => return Ok(Err(err)),
};
values.push((ty, value));
}
}
TypeEnum::TObj { fields, obj_id, .. }
@@ -893,7 +896,10 @@ pub fn attributes_writeback<'ctx>(
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, ty)?.unwrap();
let obj = match inner_resolver.get_obj_value(py, val, ctx, ty)? {
Ok(v) => v.unwrap(),
Err(err) => return Ok(Err(err)),
};
for (name, (field_ty, attr_kind)) in fields {
if !attr_kind.is_mutable() {
continue;
@@ -903,11 +909,14 @@ pub fn attributes_writeback<'ctx>(
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(
match ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
) {
Ok(v) => v,
Err(e) => return Ok(Err(e)),
},
));
}
}
@@ -940,9 +949,8 @@ pub fn attributes_writeback<'ctx>(
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, true) {
return Ok(Err(e));
}
Ok(Ok(()))
})
.unwrap()?;
Ok(anyhow::Ok(()))
})??;
Ok(())
}
@@ -995,23 +1003,27 @@ fn polymorphic_print<'ctx>(
suffix: Option<&str>,
as_repr: bool,
as_rtio: bool,
) -> Result<(), String> {
let printf =
|ctx: &mut CodeGenContext<'ctx, '_>, fmt: String, args: Vec<BasicValueEnum<'ctx>>| {
debug_assert!(!fmt.is_empty());
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
) -> anyhow::Result<()> {
let printf = |ctx: &mut CodeGenContext<'ctx, '_>,
fmt: String,
args: Vec<BasicValueEnum<'ctx>>|
-> anyhow::Result<()> {
debug_assert!(!fmt.is_empty());
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
let llvm_i32 = ctx.i32;
let llvm_i32 = ctx.i32;
let fmt = ctx.gen_string(fmt);
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
let fmt = ctx.gen_string(fmt)?;
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
if as_rtio {
call_extern!(ctx: void _ = "rtio_log"(fmt; ...args));
} else {
call_extern!(ctx: llvm_i32 _ = "core_log"(fmt; ...args));
}
};
if as_rtio {
call_extern!(ctx: void _ = "rtio_log"(fmt; ...args))?;
} else {
call_extern!(ctx: llvm_i32 _ = "core_log"(fmt; ...args))?;
}
Ok(())
};
let llvm_i32 = ctx.i32;
let llvm_i64 = ctx.i64;
@@ -1027,13 +1039,14 @@ fn polymorphic_print<'ctx>(
args: &mut Vec<BasicValueEnum<'ctx>>| {
if !fmt.is_empty() {
fmt.push('\0');
printf(ctx, mem::take(fmt), mem::take(args));
printf(ctx, mem::take(fmt), mem::take(args))?;
}
anyhow::Ok(())
};
for (ty, value) in values {
let ty = *ty;
let value = value.to_basic_value_enum(ctx, ty).unwrap();
let value = value.to_basic_value_enum(ctx, ty)?;
if !fmt.is_empty() {
fmt.push_str(separator);
@@ -1042,26 +1055,25 @@ fn polymorphic_print<'ctx>(
match &*ctx.unifier.get_ty_immutable(ty) {
TypeEnum::TTuple { ty: tys, is_vararg_ctx: false } => {
let pvalue = {
let pvalue = gen_var(ctx, value.get_type(), None);
typed_store(ctx.builder, pvalue, value);
let pvalue = gen_var(ctx, value.get_type(), None)?;
typed_store(ctx.builder, pvalue, value)?;
pvalue
};
fmt.push('(');
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
let tuple_vals = tys
.iter()
.enumerate()
.map(|(i, ty)| {
(*ty, {
let pfield =
ctx.builder.build_struct_gep(pvalue, i as u32, "").unwrap();
anyhow::Ok((*ty, {
let pfield = ctx.builder.build_struct_gep(pvalue, i as u32, "")?;
ValueEnum::from(ctx.builder.build_load(pfield, "").unwrap())
})
ValueEnum::from(ctx.builder.build_load(pfield, "")?)
}))
})
.collect_vec();
.collect::<Result<Vec<_>, _>>()?;
polymorphic_print(ctx, &tuple_vals, ", ", None, true, as_rtio)?;
@@ -1080,21 +1092,21 @@ fn polymorphic_print<'ctx>(
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Bool.id() => {
fmt.push_str("%.*s");
let true_str = ctx.gen_string("True");
let true_str = ctx.gen_string("True")?;
let true_data =
unsafe { true_str.get_field_at_index_unchecked(0) }.into_pointer_value();
let true_len = unsafe { true_str.get_field_at_index_unchecked(1) }.into_int_value();
let false_str = ctx.gen_string("False");
let false_str = ctx.gen_string("False")?;
let false_data =
unsafe { false_str.get_field_at_index_unchecked(0) }.into_pointer_value();
let false_len =
unsafe { false_str.get_field_at_index_unchecked(1) }.into_int_value();
let bool_val = bool_to_i1(ctx, value.into_int_value());
let bool_val = bool_to_i1(ctx, value.into_int_value())?;
args.extend([
ctx.builder.build_select(bool_val, true_len, false_len, "").unwrap(),
ctx.builder.build_select(bool_val, true_data, false_data, "").unwrap(),
ctx.builder.build_select(bool_val, true_len, false_len, "")?,
ctx.builder.build_select(bool_val, true_data, false_data, "")?,
]);
}
@@ -1143,13 +1155,12 @@ fn polymorphic_print<'ctx>(
let elem_ty = *params.iter().next().unwrap().1;
fmt.push('[');
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
let val = ListType::from_unifier_type(ctx, ty)
.map_value(value.into_pointer_value(), None);
let len = val.load(ctx, field!(len));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
let len = val.load(ctx, field!(len))?;
let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "")?;
gen_for_callback_incrementing(
&mut (),
@@ -1158,7 +1169,7 @@ fn polymorphic_print<'ctx>(
llvm_usize.const_zero(),
(len, false),
|(), ctx, _, i| {
let elem = val.data(ctx).get_unchecked(ctx, &i, None);
let elem = val.data(ctx)?.get_unchecked(ctx, &i, None)?;
polymorphic_print(ctx, &[(elem_ty, elem)], "", None, true, as_rtio)?;
@@ -1166,13 +1177,10 @@ fn polymorphic_print<'ctx>(
&mut (),
ctx,
|(), ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::ULT, i, last, "")
.unwrap())
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, last, "")?)
},
|(), ctx| {
printf(ctx, ", \0".into(), Vec::default());
printf(ctx, ", \0".into(), Vec::default())?;
Ok(())
},
@@ -1186,12 +1194,12 @@ fn polymorphic_print<'ctx>(
)?;
fmt.push(']');
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
fmt.push_str("array([");
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray = NDArrayType::from_unifier_type(ctx, ty)
@@ -1201,22 +1209,20 @@ fn polymorphic_print<'ctx>(
// Print `ndarray` as a flat list delimited by interspersed with ", \0"
ndarray.foreach(ctx, |ctx, _, hdl| {
let i = hdl.get_index(ctx);
let scalar = hdl.get_scalar(ctx);
let i = hdl.get_index(ctx)?;
let scalar = hdl.get_scalar(ctx)?;
// if (i != 0) puts(", ");
gen_if_callback(
&mut (),
ctx,
|(), ctx| {
let not_first = ctx
.builder
.build_int_compare(IntPredicate::NE, i, num_0, "")
.unwrap();
let not_first =
ctx.builder.build_int_compare(IntPredicate::NE, i, num_0, "")?;
Ok(not_first)
},
|(), ctx| {
printf(ctx, ", \0".into(), Vec::default());
printf(ctx, ", \0".into(), Vec::default())?;
Ok(())
},
|(), _| Ok(()),
@@ -1228,16 +1234,16 @@ fn polymorphic_print<'ctx>(
})?;
fmt.push_str(")]");
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Range.id() => {
fmt.push_str("range(");
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
let val = RangeType::new(ctx).map_value(value.into_pointer_value(), None);
let (start, stop, step) = destructure_range(ctx, val);
let (start, stop, step) = destructure_range(ctx, val)?;
polymorphic_print(
ctx,
@@ -1263,10 +1269,10 @@ fn polymorphic_print<'ctx>(
);
let exn = ExceptionType::new(ctx).map_value(value.into_pointer_value(), None);
let name = exn.load(ctx, field!(name));
let param0 = exn.load(ctx, field!(param0));
let param1 = exn.load(ctx, field!(param1));
let param2 = exn.load(ctx, field!(param2));
let name = exn.load(ctx, field!(name))?;
let param0 = exn.load(ctx, field!(param0))?;
let param1 = exn.load(ctx, field!(param1))?;
let param2 = exn.load(ctx, field!(param2))?;
fmt.push_str(fmt_str.as_str());
args.extend_from_slice(&[name.into(), param0.into(), param1.into(), param2.into()]);
@@ -1280,7 +1286,7 @@ fn polymorphic_print<'ctx>(
}
fmt.push_str(suffix);
flush(ctx, &mut fmt, &mut args);
flush(ctx, &mut fmt, &mut args)?;
Ok(())
}
@@ -1289,7 +1295,7 @@ fn polymorphic_print<'ctx>(
pub fn call_core_log_impl<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<(), String> {
) -> anyhow::Result<()> {
let (arg_ty, arg_val) = arg;
polymorphic_print(ctx, &[(arg_ty, arg_val.into())], " ", Some("\n"), false, false)?;
@@ -1302,7 +1308,7 @@ pub fn call_rtio_log_impl<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
channel: StructValue<'ctx>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<(), String> {
) -> anyhow::Result<()> {
let (arg_ty, arg_val) = arg;
polymorphic_print(
@@ -1324,7 +1330,7 @@ pub fn gen_core_log<'ctx>(
obj: Option<&(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<(), String> {
) -> anyhow::Result<()> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@@ -1340,7 +1346,7 @@ pub fn gen_rtio_log<'ctx>(
obj: Option<&(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<(), String> {
) -> anyhow::Result<()> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);

View File

@@ -585,88 +585,107 @@ impl Nac3 {
StmtKind::Import { .. } => true,
StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
// Check if the class is a NAC3 class by looking for `compile` decorator
let nac3_class = Python::attach(|py| {
let nac3_class = Python::attach(|py| -> PyResult<_> {
let module = module.bind(py);
decorator_list.iter().any(|decorator| {
is_decor_fn_same(
decorator,
module,
&[self.primitive_ids.artiq.compile_decor_fn],
)
.unwrap()
decorator_list.iter().try_fold(false, |acc, decorator| {
Ok(acc
|| is_decor_fn_same(
decorator,
module,
&[self.primitive_ids.artiq.compile_decor_fn],
)?)
})
});
})?;
if !nac3_class {
continue;
}
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::attach(|py| -> PyResult<bool> {
let Some((path, id)) = class_expr_id_path(base) else {
return Ok(true);
};
*bases = {
let len = bases.len();
bases.drain(..).try_fold(Vec::with_capacity(len), |mut bases, base| {
let retain = Python::attach(|py| -> PyResult<_> {
let Some((path, id)) = class_expr_id_path(&base) else {
return Ok(true);
};
let module = module.bind(py);
let Some(base_obj) = resolve_qname((path, id), module)? else {
return Ok(false);
};
let base_id = py_interp::extract_id(&base_obj)?;
Ok(base_id == self.primitive_ids.builtins.exception
|| base_id == self.primitive_ids.typing.generic
|| registered_class_ids.contains(&base_id))
})
.unwrap()
});
body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
Python::attach(|py| {
let module = module.bind(py);
let Some(base_obj) = resolve_qname((path, id), module)? else {
return Ok(false);
};
let base_id = py_interp::extract_id(&base_obj)?;
// Keep all class functions decorated with `kernel`, `portable`, or `rpc` decorator
decorator_list.iter().any(|decorator| {
is_decor_fn_same(
decorator,
module,
&[
self.primitive_ids.artiq.kernel_decor_fn,
self.primitive_ids.artiq.portable_decor_fn,
self.primitive_ids.artiq.rpc_decor_fn,
],
)
.unwrap()
})
})
} else {
true
}
});
Ok(base_id == self.primitive_ids.builtins.exception
|| base_id == self.primitive_ids.typing.generic
|| registered_class_ids.contains(&base_id))
})?;
if retain {
bases.push(base);
}
Ok::<_, PyErr>(bases)
})?
};
*body = {
let len = body.len();
body.drain(..).try_fold(Vec::with_capacity(len), |mut body, stmt| {
let retain =
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node
{
Python::attach(|py| -> PyResult<_> {
let module = module.bind(py);
// Keep all class functions decorated with `kernel`, `portable`, or `rpc` decorator
decorator_list.iter().try_fold(false, |acc, decorator| {
Ok(acc
|| is_decor_fn_same(
decorator,
module,
&[
self.primitive_ids.artiq.kernel_decor_fn,
self.primitive_ids.artiq.portable_decor_fn,
self.primitive_ids.artiq.rpc_decor_fn,
],
)?)
})
})?
} else {
true
};
if retain {
body.push(stmt);
}
Ok::<_, PyErr>(body)
})?
};
true
}
StmtKind::FunctionDef { ref decorator_list, .. } => {
Python::attach(|py| {
Python::attach(|py| -> PyResult<_> {
let module = module.bind(py);
// Keep all top-level functions decorated with `extern`, `kernel`, `portable`, or `rpc` decorator
decorator_list.iter().any(|decorator| {
is_decor_fn_same(
decorator,
module,
&[
self.primitive_ids.artiq.extern_decor_fn,
self.primitive_ids.artiq.kernel_decor_fn,
self.primitive_ids.artiq.portable_decor_fn,
self.primitive_ids.artiq.rpc_decor_fn,
],
)
.unwrap()
decorator_list.iter().try_fold(false, |acc, decorator| {
Ok(acc
|| is_decor_fn_same(
decorator,
module,
&[
self.primitive_ids.artiq.extern_decor_fn,
self.primitive_ids.artiq.kernel_decor_fn,
self.primitive_ids.artiq.portable_decor_fn,
self.primitive_ids.artiq.rpc_decor_fn,
],
)?)
})
})
})?
}
_ => false,
@@ -843,11 +862,10 @@ impl Nac3 {
size_t,
);
let store_obj = embedding_map.getattr("store_object").unwrap();
let store_str = embedding_map.getattr("store_str").unwrap();
let store_fun = embedding_map.getattr("store_function").unwrap().into_py_any(py)?;
let host_attributes =
embedding_map.getattr("attributes_writeback").unwrap().into_py_any(py)?;
let store_obj = embedding_map.getattr("store_object")?;
let store_str = embedding_map.getattr("store_str")?;
let store_fun = embedding_map.getattr("store_function")?.into_py_any(py)?;
let host_attributes = embedding_map.getattr("attributes_writeback")?.into_py_any(py)?;
let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new()));
let helper = PythonHelper {
store_obj: Arc::new(store_obj.clone().into_py_any(py)?),
@@ -884,8 +902,8 @@ impl Nac3 {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members = py_module.dict();
for (key, val) in members {
let key: &str = key.extract().unwrap();
let val = py_interp::extract_id(&val).unwrap();
let key: &str = key.extract()?;
let val = py_interp::extract_id(&val)?;
name_to_pyid.insert(key.into(), val);
}
let resolver = Arc::new(Resolver(Arc::new(InnerResolver {
@@ -921,7 +939,7 @@ impl Nac3 {
let py_module = module.bind(py).cast::<PyModule>()?;
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
let class = py_module.getattr(name.to_string().as_str())?;
if py_interp::extract_issubclass(&class, py_interp::get_exception_class(py)?)?
&& class.getattr("artiq_builtin").is_err()
{
@@ -967,15 +985,10 @@ impl Nac3 {
let decor_fn_id = py_interp::extract_id(&decor_fn.into_pyobject(py)?)?;
if decor_fn_id == self.primitive_ids.artiq.rpc_decor_fn {
store_fun
.call1(
py,
(
def_id.0.into_py_any(py)?,
py_module.getattr(name.to_string()).unwrap(),
),
)
.unwrap();
store_fun.call1(
py,
(def_id.0.into_py_any(py)?, py_module.getattr(name.to_string())?),
)?;
let is_async = decorator_list
.iter()
.flat_map(get_decorator_flags)
@@ -1001,7 +1014,7 @@ impl Nac3 {
}
StmtKind::ClassDef { name, body, .. } => {
let class_name = name.to_string();
let class_obj = py_module.getattr(class_name.as_str()).unwrap();
let class_obj = py_module.getattr(class_name.as_str())?;
for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
for decorator in decorator_list {
@@ -1138,7 +1151,7 @@ impl Nac3 {
if let Err(e) = composer.start_analysis(true) {
// report error of __modinit__ separately
return if e.iter().any(|err| err.contains("<nac3_synthesized_modinit>")) {
return if e.iter().any(|err| err.to_string().contains("<nac3_synthesized_modinit>")) {
let msg = Self::report_modinit(
&arg_names,
method_name,
@@ -1149,12 +1162,17 @@ impl Nac3 {
);
Err(CompileError::new_err(format!(
"compilation failed\n----------\n{}",
msg.unwrap_or_else(|| e.iter().sorted().join("\n----------\n"))
msg.unwrap_or_else(|| e
.iter()
.sorted_by(|lhs, rhs| Ord::cmp(&lhs.to_string(), &rhs.to_string()))
.join("\n----------\n"))
)))
} else {
Err(CompileError::new_err(format!(
"compilation failed\n----------\n{}",
e.iter().sorted().join("\n----------\n"),
e.iter()
.sorted_by(|lhs, rhs| Ord::cmp(&lhs.to_string(), &rhs.to_string()))
.join("\n----------\n"),
)))
};
}
@@ -1177,15 +1195,13 @@ impl Nac3 {
&mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen_callback(*is_async));
store_fun
.call1(
py,
(
id.0.into_py_any(py)?,
class_def.getattr(name.to_string().as_str()).unwrap(),
),
)
.unwrap();
store_fun.call1(
py,
(
id.0.into_py_any(py)?,
class_def.getattr(name.to_string().as_str()).unwrap(),
),
)?;
}
}
}
@@ -1294,7 +1310,7 @@ impl Nac3 {
membuffer.lock().push(buffer);
});
embedding_map.setattr("expects_return", has_return).unwrap();
embedding_map.setattr("expects_return", has_return)?;
let emit_llvm_bc = std::env::var(ENV_NAC3_EMIT_LLVM_BC).is_ok();
let emit_llvm_ll = std::env::var(ENV_NAC3_EMIT_LLVM_LL).is_ok();
@@ -1362,13 +1378,12 @@ impl Nac3 {
}
emit_llvm(&main, "main.post-opt");
Python::attach(|py| {
Python::attach(|py| -> PyResult<()> {
let string_store = self.string_store.read();
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
string_store_vec.sort_by(|(_s1, key1), (_s2, key2)| key1.cmp(key2));
for (s, key) in string_store_vec {
let embed_key: i32 =
helper.store_str.bind(py).call1((s,)).unwrap().extract().unwrap();
let embed_key: i32 = helper.store_str.bind(py).call1((s,))?.extract()?;
assert_eq!(
embed_key, *key,
"string {s} is out of sync between embedding map (key={embed_key}) and \
@@ -1376,7 +1391,8 @@ impl Nac3 {
);
}
drop(string_store);
});
Ok(())
})?;
link_fn(&main)
}
@@ -1578,7 +1594,7 @@ impl Nac3 {
"now_mu".into(),
FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx)))
Ok(Some(time_fns.emit_now_mu(ctx)?))
}))),
),
(
@@ -1595,8 +1611,8 @@ impl Nac3 {
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg);
let arg = args[0].1.clone().to_basic_value_enum(ctx, arg_ty)?;
time_fns.emit_at_mu(ctx, arg)?;
Ok(None)
}))),
),
@@ -1614,14 +1630,14 @@ impl Nac3 {
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg);
let arg = args[0].1.clone().to_basic_value_enum(ctx, arg_ty)?;
time_fns.emit_delay_mu(ctx, arg)?;
Ok(None)
}))),
),
];
let get_artiq_builtin_id = |mod_name: Option<&str>, name: &str| -> PyResult<u64> {
let get_artiq_builtin_id = |mod_name: Option<&str>, name: &str| -> PyResult<_> {
let dict = if let Some(mod_name) = mod_name {
artiq_builtins
.get_item(mod_name)?
@@ -1883,7 +1899,7 @@ impl Nac3 {
}
let get_special_ids =
|name: &str| -> PyResult<u64> { special_ids.get_item(name)?.unwrap().extract::<u64>() };
|name: &str| -> PyResult<_> { special_ids.get_item(name)?.unwrap().extract::<u64>() };
self.special_ids = SpecialPythonId {
parallel: get_special_ids("parallel")?,

File diff suppressed because it is too large Load Diff

View File

@@ -6,13 +6,24 @@ use nac3core::{
/// Functions for manipulating the timeline.
pub trait TimeFns {
/// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
fn emit_now_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<BasicValueEnum<'ctx>>;
/// Emits LLVM IR for `at_mu`.
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>);
fn emit_at_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
t: BasicValueEnum<'ctx>,
) -> anyhow::Result<()>;
/// Emits LLVM IR for `delay_mu`.
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>);
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) -> anyhow::Result<()>;
}
pub struct NowPinningTimeFns64 {}
@@ -20,7 +31,10 @@ pub struct NowPinningTimeFns64 {}
// For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit
// values that are each padded to 64-bits.
impl TimeFns for NowPinningTimeFns64 {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
fn emit_now_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<BasicValueEnum<'ctx>> {
let i64_type = ctx.i64;
let i32_type = ctx.i32;
let now = ctx
@@ -29,72 +43,65 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")?
.into_pointer_value();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")?
};
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")?.into_int_value();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")?.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "")?;
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "")?;
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "")?;
Ok(ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into)?)
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
fn emit_at_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
t: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
let i32_type = ctx.i32;
let i64_type = ctx.i64;
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi")?,
i32_type,
"",
)?;
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo")?;
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")?
.into_pointer_value();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
typed_store(ctx.builder, now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
typed_store(ctx.builder, now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")?
};
typed_store(ctx.builder, now_hiptr, time_hi)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
typed_store(ctx.builder, now_loptr, time_lo)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
Ok(())
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
let i64_type = ctx.i64;
let i32_type = ctx.i32;
let now = ctx
@@ -103,52 +110,37 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")?
.into_pointer_value();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")?
};
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")?.into_int_value();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")?.into_int_value();
let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "")?;
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "")?;
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "")?;
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now")?;
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder
.build_right_shift(time, i64_type.const_int(32, false), false, "")
.unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time")?;
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_type.const_int(32, false), false, "")?,
i32_type,
"time.hi",
)?;
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo")?;
typed_store(ctx.builder, now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
typed_store(ctx.builder, now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
typed_store(ctx.builder, now_hiptr, time_hi)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
typed_store(ctx.builder, now_loptr, time_lo)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
Ok(())
}
}
@@ -157,63 +149,65 @@ pub static NOW_PINNING_TIME_FNS_64: NowPinningTimeFns64 = NowPinningTimeFns64 {}
pub struct NowPinningTimeFns {}
impl TimeFns for NowPinningTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
fn emit_now_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<BasicValueEnum<'ctx>> {
let i64_type = ctx.i64;
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now")?.into_int_value();
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo")?;
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi")?;
Ok(ctx.builder.build_or(now_lo, now_hi, "now_mu")?.into())
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
fn emit_at_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
t: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
let i32_type = ctx.i32;
let i64_type = ctx.i64;
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "")?,
i32_type,
"time.hi",
)?;
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc")?;
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")?
.into_pointer_value();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
typed_store(ctx.builder, now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
typed_store(ctx.builder, now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")?
};
typed_store(ctx.builder, now_hiptr, time_hi)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
typed_store(ctx.builder, now_loptr, time_lo)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
Ok(())
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
let i32_type = ctx.i32;
let i64_type = ctx.i64;
let i64_32 = i64_type.const_int(32, false);
@@ -221,43 +215,34 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "")?.into_int_value();
let dt = dt.into_int_value();
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"now_trunc",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo")?;
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi")?;
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val")?;
let time = ctx.builder.build_int_add(now_val, dt, "time")?;
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi")?,
i32_type,
"now_trunc",
)?;
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo")?;
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")?
.into_pointer_value();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
typed_store(ctx.builder, now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
typed_store(ctx.builder, now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")?
};
typed_store(ctx.builder, now_hiptr, time_hi)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
typed_store(ctx.builder, now_loptr, time_lo)?
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)?;
Ok(())
}
}
@@ -266,18 +251,29 @@ pub static NOW_PINNING_TIME_FNS: NowPinningTimeFns = NowPinningTimeFns {};
pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
call_extern!(ctx: (ctx.i64) "now_mu" = "now_mu"()).into()
fn emit_now_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<BasicValueEnum<'ctx>> {
Ok(call_extern!(ctx: (ctx.i64) "now_mu" = "now_mu"())?.into())
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
fn emit_at_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
t: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
assert_eq!(t.get_type(), ctx.i64.into());
call_extern!(ctx: void "at_mu" = "at_mu"(t));
call_extern!(ctx: void "at_mu" = "at_mu"(t))
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) -> anyhow::Result<()> {
assert_eq!(dt.get_type(), ctx.i64.into());
call_extern!(ctx: void "delay_mu" = "delay_mu"(dt));
call_extern!(ctx: void "delay_mu" = "delay_mu"(dt))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use anyhow::anyhow;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
@@ -234,14 +235,14 @@ impl ConcreteTypeStore {
primitives: &PrimitiveStore,
cty: ConcreteType,
cache: &mut HashMap<ConcreteType, Option<Type>>,
) -> Type {
) -> anyhow::Result<Type> {
if let Some(ty) = cache.get_mut(&cty) {
return if let Some(ty) = ty {
return Ok(if let Some(ty) = ty {
*ty
} else {
*ty = Some(unifier.get_dummy_var().ty);
ty.unwrap()
};
});
}
cache.insert(cty, None);
let result = match &self.store[cty.0] {
@@ -259,46 +260,60 @@ impl ConcreteTypeStore {
Primitive::Exception => primitives.exception,
};
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
return Ok(ty);
}
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
.map(|cty| anyhow::Ok(self.to_unifier_type(unifier, primitives, *cty, cache)?))
.collect::<anyhow::Result<_>>()?,
is_vararg_ctx: *is_vararg_ctx,
},
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache)? }
}
ConcreteTypeEnum::TObj { obj_id, fields, params } => TypeEnum::TObj {
obj_id: *obj_id,
fields: fields
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
anyhow::Ok((
*name,
(self.to_unifier_type(unifier, primitives, cty.0, cache)?, cty.1),
))
})
.collect::<HashMap<_, _>>(),
params: into_var_map(params.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
.collect::<anyhow::Result<HashMap<_, _>>>()?,
params: into_var_map(
params
.iter()
.map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache)?;
anyhow::Ok(TypeVar { id, ty })
})
.collect::<anyhow::Result<Vec<_>>>()?,
),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args
.iter()
.map(|arg| FuncArg {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: false,
.map(|arg| {
anyhow::Ok(FuncArg {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache)?,
default_value: arg.default_value.clone(),
is_vararg: false,
})
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: into_var_map(vars.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
.collect::<anyhow::Result<_>>()?,
ret: self.to_unifier_type(unifier, primitives, *ret, cache)?,
vars: into_var_map(
vars.iter()
.map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache)?;
anyhow::Ok(TypeVar { id, ty })
})
.collect::<anyhow::Result<Vec<_>>>()?,
),
}),
ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None }
@@ -306,10 +321,10 @@ impl ConcreteTypeStore {
};
let result = unifier.add_ty(result);
if let Some(ty) = cache.get(&cty).unwrap() {
unifier.unify(*ty, result).unwrap();
unifier.unify(*ty, result).map_err(|e| anyhow!("{}", e.to_display(unifier)))?;
}
cache.insert(cty, Some(result));
result
Ok(result)
}
pub fn add_cty(&mut self, cty: ConcreteTypeEnum) -> ConcreteType {

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ pub fn call_j1<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
let llvm_f64 = ctx.f64;
debug_assert_eq!(arg.get_type(), llvm_f64);
call_extern!(ctx: llvm_f64 name? = ["nounwind"] "j1"(arg))
@@ -41,7 +41,7 @@ macro_rules! generate_linalg_extern_fn {
ctx: &mut CodeGenContext<'ctx, '_>,
$($input_matrix: BasicValueEnum<'ctx>,)*
name: Option<&str>,
) {
) -> anyhow::Result<()> {
call_extern!(ctx: void name? = ["nounwind"] (stringify!($extern_fn))($($input_matrix),*))
}
};

View File

@@ -30,7 +30,7 @@ pub trait CodeGenerator {
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>>
where
Self: Sized,
{
@@ -47,7 +47,7 @@ pub trait CodeGenerator {
signature: &FunSignature,
def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String>
) -> anyhow::Result<BasicValueEnum<'ctx>>
where
Self: Sized,
{
@@ -68,7 +68,7 @@ pub trait CodeGenerator {
obj: Option<Type>,
fun: (&FunSignature, &mut TopLevelDef, String),
id: usize,
) -> Result<String, String> {
) -> anyhow::Result<String> {
gen_func_instance(ctx, obj, fun, id)
}
@@ -77,7 +77,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>,
) -> Result<RtValue<'ctx>, String>
) -> anyhow::Result<RtValue<'ctx>>
where
Self: Sized,
{
@@ -90,7 +90,7 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>,
pattern: &Expr<Option<Type>>,
name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String>
) -> anyhow::Result<Option<PointerValue<'ctx>>>
where
Self: Sized,
{
@@ -104,7 +104,7 @@ pub trait CodeGenerator {
target: &Expr<Option<Type>>,
value: &ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -120,7 +120,7 @@ pub trait CodeGenerator {
targets: &[Expr<Option<Type>>],
value: &ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -137,7 +137,7 @@ pub trait CodeGenerator {
key: &Expr<Option<Type>>,
value: &ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -150,7 +150,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -163,7 +163,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -176,7 +176,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -187,7 +187,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -201,7 +201,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{
@@ -213,7 +213,7 @@ pub trait CodeGenerator {
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmts: I,
) -> Result<(), String>
) -> anyhow::Result<()>
where
Self: Sized,
{

View File

@@ -8,7 +8,7 @@ pub fn call_isinf<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
assert_eq!(v.get_type(), ctx.f64);
call_extern!(ctx: (ctx.i1) name? = "__nac3_isinf"(v))
}
@@ -19,7 +19,7 @@ pub fn call_isnan<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
assert_eq!(v.get_type(), ctx.f64);
call_extern!(ctx: (ctx.i1) name? = "__nac3_isnan"(v))
}
@@ -40,7 +40,7 @@ macro_rules! generate_f64_nary_fn {
ctx: &mut CodeGenContext<'ctx, '_>,
$($args: FloatValue<'ctx>,)*
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
const FN_NAME: &str = (concat!("__nac3_", stringify!($builtin_fn)));
let llvm_f64 = ctx.f64;
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
@@ -74,7 +74,7 @@ pub fn call_ldexp<'ctx>(
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
let llvm_f64 = ctx.f64;
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), ctx.i32);

View File

@@ -1,8 +1,4 @@
use inkwell::{
IntPredicate,
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue},
};
use inkwell::{IntPredicate, types::BasicTypeEnum, values::IntValue};
use crate::codegen::{
CodeGenContext,
@@ -23,7 +19,7 @@ pub fn list_slice_assignment<'ctx>(
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
) -> anyhow::Result<()> {
let llvm_usize = ctx.size_t;
let llvm_i32 = ctx.i32;
@@ -38,12 +34,10 @@ pub fn list_slice_assignment<'ctx>(
let zero = llvm_i32.const_zero();
let one = llvm_i32.const_int(1, false);
let (dest_ptr, dest_len) = dest_arr.data(ctx).value;
let dest_len =
ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap();
let (src_ptr, src_len) = src_arr.data(ctx).value;
let src_len =
ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap();
let (dest_ptr, dest_len) = dest_arr.data(ctx)?.value;
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32")?;
let (src_ptr, src_len) = src_arr.data(ctx)?.value;
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32")?;
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
@@ -51,51 +45,50 @@ pub fn list_slice_assignment<'ctx>(
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg")?,
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one")?,
ctx.builder.build_int_add(src_idx.1, one, "e_add_one")?,
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
)?
.into_int_value();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg")?,
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one")?,
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one")?,
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len = calculate_len_for_slice_range(ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len = calculate_len_for_slice_range(ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
)?
.into_int_value();
let src_slice_len = calculate_len_for_slice_range(ctx, src_idx.0, src_end, src_idx.2)?;
let dest_slice_len = calculate_len_for_slice_range(ctx, dest_idx.0, dest_end, dest_idx.2)?;
let src_eq_dest = ctx.builder.build_int_compare(
IntPredicate::EQ,
src_slice_len,
dest_slice_len,
"slice_src_eq_dest",
)?;
let src_slt_dest = ctx.builder.build_int_compare(
IntPredicate::SLT,
src_slice_len,
dest_slice_len,
"slice_src_slt_dest",
)?;
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)?;
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1")?;
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond")?;
ctx.make_assert(
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
)?;
let new_len = call_extern!(ctx: llvm_i32 "slice_assign" = fun_symbol(
dest_idx.0, // dest start idx
@@ -116,27 +109,25 @@ pub fn list_slice_assignment<'ctx>(
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap()
ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size")?
}
));
))?;
// update length
gen_if_callback(
&mut (),
ctx,
|(), ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")
.unwrap())
Ok(ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")?)
},
|(), ctx| {
let new_len =
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
dest_arr.store(ctx, field!(len), new_len);
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len")?;
dest_arr.store(ctx, field!(len), new_len)?;
Ok(())
},
|(), _| Ok(()),
)
.unwrap();
)?;
Ok(())
}

View File

@@ -12,7 +12,7 @@ pub fn integer_power<'ctx>(
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
let base_type = base.get_type();
let symbol = match (base_type.get_bit_width(), exp.get_type().get_bit_width(), signed) {
@@ -24,22 +24,19 @@ pub fn integer_power<'ctx>(
};
// throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
let ge_zero = ctx.builder.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)?;
ctx.make_assert(
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
)?;
call_extern!(ctx: base_type "call_int_pow" = symbol(base, exp))
}
@@ -48,14 +45,17 @@ pub fn integer_power<'ctx>(
pub fn call_gammaln<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
let llvm_f64 = ctx.f64;
assert_eq!(v.get_type(), llvm_f64);
call_extern!(ctx: llvm_f64 "gammaln" = "__nac3_gammaln"(v))
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &mut CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
pub fn call_j0<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> anyhow::Result<FloatValue<'ctx>> {
let llvm_f64 = ctx.f64;
assert_eq!(v.get_type(), llvm_f64);
call_extern!(ctx: llvm_f64 "j0" = "__nac3_j0"(v))

View File

@@ -135,12 +135,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
) -> anyhow::Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>> {
let llvm_i32 = ctx.i32;
let zero = llvm_i32.const_zero();
let one = llvm_i32.const_int(1, false);
let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap();
let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32")?;
Ok(Some(match (start, end, step) {
(s, e, None) => (
if let Some(s) = s.as_ref() {
@@ -160,7 +160,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
} else {
length
};
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
ctx.builder.build_int_sub(e, one, "final_end")?
},
one,
),
@@ -168,27 +168,22 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = generator.gen_expr(ctx, step)?.to_basic_value_enum(ctx)?.into_int_value();
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)?;
ctx.make_assert(
not_zero,
"0:ValueError",
"slice step cannot be zero",
[None, None, None],
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
)?;
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1")?;
let neg =
ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")?;
(
match s {
Some(s) => {
@@ -197,32 +192,26 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
};
ctx.builder
.build_select(
ctx.builder
.build_and(
ctx.builder
.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)
.unwrap(),
neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
ctx.builder.build_and(
ctx.builder.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)?,
neg,
"should_minus_one",
)?,
ctx.builder.build_int_sub(s, one, "s_min")?,
s,
"final_start",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
.map(BasicValueEnum::into_int_value)?
}
None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
.map(BasicValueEnum::into_int_value)?,
},
match e {
Some(e) => {
@@ -232,18 +221,16 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
ctx.builder
.build_select(
neg,
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
ctx.builder.build_int_add(e, one, "end_add_one")?,
ctx.builder.build_int_sub(e, one, "end_sub_one")?,
"final_end",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
.map(BasicValueEnum::into_int_value)?
}
None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
.map(BasicValueEnum::into_int_value)?,
},
step,
)

View File

@@ -14,24 +14,26 @@ pub fn calculate_len_for_slice_range<'ctx>(
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
let llvm_i32 = ctx.i32;
assert_eq!(start.get_type(), llvm_i32);
assert_eq!(end.get_type(), llvm_i32);
assert_eq!(step.get_type(), llvm_i32);
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)?;
ctx.make_assert(
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
)?;
call_extern!(ctx: llvm_i32 "calc_len" = "__nac3_range_slice_len"(start, end, step))
}

View File

@@ -13,11 +13,11 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
) -> anyhow::Result<Option<IntValue<'ctx>>> {
let llvm_i32 = ctx.i32;
assert_eq!(length.get_type(), llvm_i32);
let i = generator.gen_expr(ctx, i)?.to_basic_value_enum(ctx)?;
Ok(Some(call_extern!(ctx: llvm_i32 "bounded_ind" = "__nac3_slice_index_bound"(i, length))))
Ok(Some(call_extern!(ctx: llvm_i32 "bounded_ind" = "__nac3_slice_index_bound"(i, length))?))
}

View File

@@ -9,15 +9,15 @@ pub fn call_string_eq<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
str1: StringValue<'ctx>,
str2: StringValue<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
let llvm_i1 = ctx.i1;
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
call_extern!(ctx: llvm_i1 "str_eq_call" = func_name(
str1.ptr(ctx),
str1.len(ctx),
str2.ptr(ctx),
str2.len(ctx),
str1.ptr(ctx)?,
str1.len(ctx)?,
str2.ptr(ctx)?,
str2.len(ctx)?,
))
}

View File

@@ -213,9 +213,9 @@ impl<'ctx> FunctionStore<'ctx> {
decl: &FunctionDecl<'ctx>,
builder: &Builder<'ctx>,
args: &[T],
call: impl FnOnce(FunctionValue<'ctx>, &[T]) -> CallSiteValue<'ctx>,
mut alloca: impl FnMut(BasicTypeEnum<'ctx>) -> PointerValue<'ctx>,
) -> Option<BasicValueEnum<'ctx>>
call: impl FnOnce(FunctionValue<'ctx>, &[T]) -> anyhow::Result<CallSiteValue<'ctx>>,
mut alloca: impl FnMut(BasicTypeEnum<'ctx>) -> anyhow::Result<PointerValue<'ctx>>,
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>>
where
T: Copy + TryInto<BasicValueEnum<'ctx>, Error: Debug>,
BasicValueEnum<'ctx>: Into<T>,
@@ -232,9 +232,9 @@ impl<'ctx> FunctionStore<'ctx> {
if param.get_element_type().is_struct_type()
&& p.get_type().get_element_type().is_struct_type()
{
arg = ptr_to_t(builder.build_pointer_cast(p, param, "").unwrap());
arg = ptr_to_t(builder.build_pointer_cast(p, param, "")?);
}
arg
anyhow::Ok(arg)
};
let (value, ref info) = self.functions[&decl.name];
@@ -244,24 +244,27 @@ impl<'ctx> FunctionStore<'ctx> {
let slot = match *ret {
Some(TyAndCallConv { ty, call_conv: ArgCallConv::Indirect(attr) }) => {
Some((alloca(ty), attr))
Some((alloca(ty)?, attr))
}
_ => None,
};
let normal_args = params.iter().map(|&TyAndCallConv { ty, call_conv }| {
let mut next = *args.next().expect("arguments fewer than parameters");
if let BasicTypeEnum::PointerType(p) = ty {
next = fixup_ptr_arg(next, p);
}
let normal_args = params
.iter()
.map(|&TyAndCallConv { ty, call_conv }| {
let mut next = *args.next().expect("arguments fewer than parameters");
if let BasicTypeEnum::PointerType(p) = ty {
next = fixup_ptr_arg(next, p)?;
}
if let ArgCallConv::Indirect(attr) = call_conv {
let p = alloca(ty);
typed_store(builder, p, next.try_into().unwrap());
(ptr_to_t(p), attr)
} else {
(next, None)
}
});
if let ArgCallConv::Indirect(attr) = call_conv {
let p = alloca(ty)?;
typed_store(builder, p, next.try_into().unwrap())?;
anyhow::Ok((ptr_to_t(p), attr))
} else {
Ok((next, None))
}
})
.collect::<Result<Vec<_>, _>>()?;
let normal_slot = slot.map(|(p, attr)| (ptr_to_t(p), attr));
let (mut llvm_args, attrs): (Vec<_>, Vec<_>) =
@@ -272,7 +275,7 @@ impl<'ctx> FunctionStore<'ctx> {
assert!(args.as_slice().is_empty(), "too many arguments");
}
let result = call(value, &llvm_args);
let result = call(value, &llvm_args)?;
for (loc, attr) in get_attrs(attrs) {
result.add_attribute(loc, attr);
}
@@ -280,10 +283,10 @@ impl<'ctx> FunctionStore<'ctx> {
let mut result = result.try_as_basic_value().basic();
if let Some((ptr, _)) = slot {
assert!(result.is_none());
result = Some(builder.build_load(ptr, "slot").unwrap());
result = Some(builder.build_load(ptr, "slot")?);
}
assert_eq!(result.map(|val| val.get_type()), ret.map(|ret_type| ret_type.ty));
result
Ok(result)
}
FunctionInfo::Internal { ret, params, export } => {
assert!(!export, "attempted to call a non-exported function");
@@ -295,16 +298,16 @@ impl<'ctx> FunctionStore<'ctx> {
if let BasicMetadataTypeEnum::PointerType(p) = *param {
fixup_ptr_arg(arg, p)
} else {
arg
Ok(arg)
}
})
.collect_vec();
.collect::<Result<Vec<_>, _>>()?;
let inst = call(value, &args);
let inst = call(value, &args)?;
inst.set_call_convention(INTERNAL_CALL_CONV);
let result = inst.try_as_basic_value().basic();
assert_eq!(result.map(|val| val.get_type()), *ret);
result
Ok(result)
}
}
}

View File

@@ -14,12 +14,12 @@ fn call_intrinsic_impl<'ctx>(
type_params: &[BasicTypeEnum<'ctx>],
args: &[BasicMetadataValueEnum<'ctx>],
call_name: Option<&str>,
) -> Option<BasicValueEnum<'ctx>> {
) -> anyhow::Result<Option<BasicValueEnum<'ctx>>> {
let intrin = Intrinsic::find(intrin)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, type_params))
.expect("intrinsic not found");
let result = ctx.builder.build_call(intrin, args, call_name.unwrap_or_default()).unwrap();
result.try_as_basic_value().basic()
let result = ctx.builder.build_call(intrin, args, call_name.unwrap_or_default())?;
Ok(result.try_as_basic_value().basic())
}
macro_rules! call_intrinsic {
@@ -27,16 +27,18 @@ macro_rules! call_intrinsic {
call_intrinsic_impl($ctx, concat!("llvm.", $intrin), &[$($($type_param.into()),*)?], &[$($arg.into()),*], $call_name)
}};
($ctx: expr, $call_name: expr, $intrin:literal $([$($type_param:expr),*])? ($($arg:expr),*) -> void) => {{
assert!(call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).is_none())
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).map(|v| {
assert!(v.is_none());
})
}};
($ctx: expr, $call_name: expr, $intrin:literal $([$($type_param:expr),*])? ($($arg:expr),*) -> int) => {{
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).unwrap().into_int_value()
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).map(|v| v.unwrap().into_int_value())
}};
($ctx: expr, $call_name: expr, $intrin:literal $([$($type_param:expr),*])? ($($arg:expr),*) -> float) => {{
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).unwrap().into_float_value()
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).map(|v| v.unwrap().into_float_value())
}};
($ctx: expr, $call_name: expr, $intrin:literal $([$($type_param:expr),*])? ($($arg:expr),*) -> ptr) => {{
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).unwrap().into_pointer_value()
call_intrinsic!($ctx, $call_name, $intrin $([$($type_param),*])? ($($arg),*)).map(|v| v.unwrap().into_pointer_value())
}};
}
@@ -53,29 +55,37 @@ macro_rules! llvm_doc {
}
#[doc = llvm_doc!("va_start")]
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
call_intrinsic!(ctx, None, "va_start"(arglist) -> void);
pub fn call_va_start<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arglist: PointerValue<'ctx>,
) -> anyhow::Result<()> {
call_intrinsic!(ctx, None, "va_start"(arglist) -> void)
}
#[doc = llvm_doc!("va_end")]
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
call_intrinsic!(ctx, None, "va_end"(arglist) -> void);
pub fn call_va_end<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arglist: PointerValue<'ctx>,
) -> anyhow::Result<()> {
call_intrinsic!(ctx, None, "va_end"(arglist) -> void)
}
#[doc = llvm_doc!("va_stacksave")]
#[must_use]
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
) -> anyhow::Result<PointerValue<'ctx>> {
call_intrinsic!(ctx, name, "stacksave"() -> ptr)
}
#[doc = llvm_doc!("va_stackrestore")]
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
call_intrinsic!(ctx, None, "stackrestore"(ptr) -> void);
pub fn call_stackrestore<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
) -> anyhow::Result<()> {
call_intrinsic!(ctx, None, "stackrestore"(ptr) -> void)
}
#[doc = llvm_doc!("memcpy")]
@@ -88,7 +98,7 @@ pub fn call_memcpy<'ctx>(
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
) {
) -> anyhow::Result<()> {
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
@@ -104,7 +114,8 @@ pub fn call_memcpy<'ctx>(
let dest_alignment = target_data.get_abi_alignment(&llvm_dest_t);
let src_alignment = target_data.get_abi_alignment(&llvm_src_t);
ctx.builder.build_memcpy(dest, dest_alignment, src, src_alignment, len).unwrap();
ctx.builder.build_memcpy(dest, dest_alignment, src, src_alignment, len)?;
Ok(())
}
#[doc = llvm_doc!("memcpy")]
@@ -116,7 +127,7 @@ pub fn call_memcpy_generic<'ctx>(
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
) {
) -> anyhow::Result<()> {
let llvm_p0i8 = ctx.ptr;
let dest_elem_t = dest.get_type().get_element_type();
@@ -125,21 +136,15 @@ pub fn call_memcpy_generic<'ctx>(
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
ctx.builder.build_bit_cast(dest, llvm_p0i8, "")?.into_pointer_value()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
ctx.builder.build_bit_cast(src, llvm_p0i8, "").map(BasicValueEnum::into_pointer_value)?
};
call_memcpy(ctx, dest, src, len);
call_memcpy(ctx, dest, src, len)
}
#[doc = llvm_doc!("memcpy")]
@@ -153,7 +158,7 @@ pub fn call_memcpy_generic_array<'ctx>(
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
) {
) -> anyhow::Result<()> {
let llvm_p0i8 = ctx.ptr;
let llvm_usize = ctx.size_t;
@@ -163,27 +168,22 @@ pub fn call_memcpy_generic_array<'ctx>(
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
ctx.builder.build_bit_cast(dest, llvm_p0i8, "")?.into_pointer_value()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
ctx.builder.build_bit_cast(src, llvm_p0i8, "")?.into_pointer_value()
};
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(src_elem_t.size_of().unwrap(), llvm_usize, "")
.unwrap();
let len = ctx.builder.build_int_mul(len, sizeof_elem, "").unwrap();
let sizeof_elem = ctx.builder.build_int_truncate_or_bit_cast(
src_elem_t.size_of().unwrap(),
llvm_usize,
"",
)?;
let len = ctx.builder.build_int_mul(len, sizeof_elem, "")?;
call_memcpy(ctx, dest, src, len);
call_memcpy(ctx, dest, src, len)
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
@@ -202,7 +202,7 @@ macro_rules! generate_llvm_intrinsic_fn {
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
call_intrinsic!(ctx, name, $llvm_name[$val.get_type()]($val) -> float)
}
};
@@ -213,7 +213,7 @@ macro_rules! generate_llvm_intrinsic_fn {
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
call_intrinsic!(ctx, name, $llvm_name[$val1.get_type()]($val1, $val2) -> float)
}
@@ -225,7 +225,7 @@ macro_rules! generate_llvm_intrinsic_fn {
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
call_intrinsic!(ctx, name, $llvm_name[$val1.get_type()]($val1, $val2) -> int)
}
@@ -236,13 +236,12 @@ macro_rules! generate_llvm_intrinsic_fn {
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
#[must_use]
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
@@ -274,22 +273,20 @@ generate_llvm_intrinsic_fn!(float call_float_round: "round"(val));
generate_llvm_intrinsic_fn!(float call_float_rint: "rint"(val));
#[doc = llvm_doc!("powi")]
#[must_use]
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
) -> anyhow::Result<FloatValue<'ctx>> {
call_intrinsic!(ctx, name, "powi"[val.get_type(), power.get_type()](val, power) -> float)
}
#[doc = llvm_doc!("ctpop")]
#[must_use]
pub fn call_int_ctpop<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
call_intrinsic!(ctx, name, "ctpop"[src.get_type()](src) -> int)
}

View File

@@ -7,7 +7,7 @@
use std::{
cell::OnceCell,
collections::{HashMap, HashSet},
collections::HashMap,
ops::ControlFlow,
sync::{
Arc,
@@ -16,6 +16,7 @@ use std::{
thread,
};
use anyhow::anyhow;
use crossbeam::channel::{Receiver, Sender, unbounded};
use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel,
@@ -447,23 +448,19 @@ impl WorkerRegistry {
let builder = context.ctx.create_builder();
let mut unifier_cache = vec![OnceCell::new(); self.top_level_ctx.unifiers.read().len()];
let mut errors = HashSet::new();
let mut errors = Vec::new();
while let Some(task) = self.receiver.recv().unwrap() {
let result =
gen_func(&mut context, &builder, generator, self, task, &mut unifier_cache);
if let Err(e) = result {
errors.insert(e);
errors.push(e);
context =
ModuleContext::new(ctx, &format!("{}_recover", generator.get_name()), options);
}
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().join("\n----------\n"));
let result = context.module.verify();
if let Err(err) = result {
@@ -596,9 +593,9 @@ pub fn typed_load<'ctx>(
ptr: PointerValue<'ctx>,
ty: BasicTypeEnum<'ctx>,
name: &str,
) -> BasicValueEnum<'ctx> {
let casted_ptr = b.build_pointer_cast(ptr, ty.ptr_type(AddressSpace::default()), "").unwrap();
b.build_load(casted_ptr, name).unwrap()
) -> anyhow::Result<BasicValueEnum<'ctx>> {
let casted_ptr = b.build_pointer_cast(ptr, ty.ptr_type(AddressSpace::default()), "")?;
Ok(b.build_load(casted_ptr, name)?)
}
/// Stores `value` into the memory location pointed to by `ptr`.
@@ -609,11 +606,10 @@ pub fn typed_store<'ctx>(
b: &Builder<'ctx>,
ptr: PointerValue<'ctx>,
value: impl BasicValue<'ctx>,
) -> InstructionValue<'ctx> {
) -> anyhow::Result<InstructionValue<'ctx>> {
let value_ty = value.as_basic_value_enum().get_type();
let casted_ptr =
b.build_pointer_cast(ptr, value_ty.ptr_type(AddressSpace::default()), "").unwrap();
b.build_store(casted_ptr, value).unwrap()
let casted_ptr = b.build_pointer_cast(ptr, value_ty.ptr_type(AddressSpace::default()), "")?;
Ok(b.build_store(casted_ptr, value)?)
}
/// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`].
@@ -701,7 +697,7 @@ pub fn gen_func_impl<
'ctx,
'a,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
F: FnOnce(&mut G, &mut CodeGenContext) -> anyhow::Result<()>,
>(
ctx: &'a mut ModuleContext<'ctx>,
builder: &'a Builder<'ctx>,
@@ -710,7 +706,7 @@ pub fn gen_func_impl<
task: CodeGenTask,
unifier_cache: &mut [OnceCell<Unifier>],
codegen_function: F,
) -> Result<FunctionValue<'ctx>, String> {
) -> anyhow::Result<FunctionValue<'ctx>> {
let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone();
let (mut unifier, primitives) = {
@@ -733,7 +729,7 @@ pub fn gen_func_impl<
for (a, b) in &task.subst {
// this should be unification between variables and concrete types
// and should not cause any problem...
let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache);
let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache)?;
unifier
.unify(*a, b)
.or_else(|err| {
@@ -744,7 +740,7 @@ pub fn gen_func_impl<
Err(err)
}
})
.unwrap();
.map_err(|e| anyhow!("{}", e.to_display(&unifier)))?;
}
// rebuild primitive store with unique representatives
@@ -790,7 +786,7 @@ pub fn gen_func_impl<
// We should not be converting back into typechecking/unifier types and then turning that into
// native LLVM types.
let ret = task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache);
let ret = task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache)?;
let ret_type = if unifier.unioned(ret, primitives.none) {
None
} else {
@@ -799,13 +795,15 @@ pub fn gen_func_impl<
let params = args
.iter()
.map(|arg| FuncArg {
name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
.map(|arg| {
anyhow::Ok(FuncArg {
name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache)?,
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
})
.collect_vec();
.collect::<anyhow::Result<Vec<_>>>()?;
let params_type = params
.iter()
.map(|arg| {
@@ -833,7 +831,7 @@ pub fn gen_func_impl<
let param = fn_val.get_nth_param(n as u32).unwrap();
let local_type = get_llvm_type(ctx, &mut unifier, &mut type_cache, arg.ty);
let alloca =
builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string()))?;
// Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() {
@@ -841,7 +839,7 @@ pub fn gen_func_impl<
let param_val = param.into_int_value();
if expected_ty.get_bit_width() == 8 && param_val.get_type().get_bit_width() == 1 {
bool_to_int_type(builder, param_val, ctx.i8)
bool_to_int_type(builder, param_val, ctx.i8)?
} else {
param_val
}
@@ -850,13 +848,14 @@ pub fn gen_func_impl<
param
};
typed_store(builder, alloca, param);
typed_store(builder, alloca, param)?;
var_assignment.insert(arg.name, VarValue::new(alloca));
}
// TODO: Save vararg parameters as list
let return_buffer = ret_type.map(|v| builder.build_alloca(v, "$ret").unwrap());
let return_buffer =
ret_type.map(|v| anyhow::Ok(builder.build_alloca(v, "$ret")?)).transpose()?;
let static_values = {
let store = registry.static_value_store.lock();
@@ -870,11 +869,11 @@ pub fn gen_func_impl<
let exception_val = {
let exn_type = ExceptionType::new(ctx).inner.llvm_ty;
let ptr = builder.build_alloca(exn_type, "exn").unwrap();
builder.build_pointer_cast(ptr, ctx.ptr, "exn").unwrap()
let ptr = builder.build_alloca(exn_type, "exn")?;
builder.build_pointer_cast(ptr, ctx.ptr, "exn")?
};
builder.build_unconditional_branch(body_bb).unwrap();
builder.build_unconditional_branch(body_bb)?;
builder.position_at_end(body_bb);
let is_optimized = registry.codegen_options.opt_level != "0";
@@ -904,8 +903,7 @@ pub fn gen_func_impl<
compile_unit.get_file(),
Some(
dibuilder
.create_basic_type("_", 0_u64, 0x00, inkwell::debug_info::DIFlags::PUBLIC)
.unwrap()
.create_basic_type("_", 0_u64, 0x00, inkwell::debug_info::DIFlags::PUBLIC)?
.as_type(),
),
&[],
@@ -967,7 +965,7 @@ pub fn gen_func_impl<
// after static analysis, only void functions can have no return at the end.
if !code_gen_context.is_terminated() {
code_gen_context.builder.build_return(None).unwrap();
code_gen_context.builder.build_return(None)?;
}
code_gen_context.builder.unset_current_debug_location();
@@ -979,7 +977,6 @@ pub fn gen_func_impl<
/// Generates LLVM IR for a function.
///
/// * `context` - The [`CoreContext`] we are inserting into.
/// * `builder` - The [`Builder`] used for generating LLVM IR.
/// * `generator` - The [`CodeGenerator`] for generating various program constructs.
/// * `registry` - The [`WorkerRegistry`] responsible for monitoring this function generation task.
/// * `task` - The [`CodeGenTask`] associated with this function generation task.
@@ -990,26 +987,24 @@ pub fn gen_func<'ctx, 'a, G: CodeGenerator>(
registry: &WorkerRegistry,
task: CodeGenTask,
unifier_cache: &mut [OnceCell<Unifier>],
) -> Result<FunctionValue<'ctx>, String> {
) -> anyhow::Result<FunctionValue<'ctx>> {
let body = task.body.clone();
gen_func_impl(context, builder, generator, registry, task, unifier_cache, |generator, ctx| {
generator.gen_block(ctx, body.iter())
})
}
#[must_use]
pub fn bool_to_i1<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
bool_to_int_type(ctx.builder, bool_value, ctx.i1)
}
#[must_use]
pub fn bool_to_i8<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
bool_to_int_type(ctx.builder, bool_value, ctx.i8)
}
@@ -1023,18 +1018,21 @@ fn bool_to_int_type<'ctx>(
builder: &Builder<'ctx>,
value: IntValue<'ctx>,
ty: IntType<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
// i1 -> i1 : %value ; no-op
// i1 -> i<N> : zext i1 %value to i<N> ; guaranteed to be 0 or 1 - see docs
// i<M> -> i<N>: zext i1 (icmp eq i<M> %value, 0) to i<N> ; same as i<M> -> i1 -> i<N>
match (value.get_type().get_bit_width(), ty.get_bit_width()) {
(1, 1) => value,
(1, _) => builder.build_int_z_extend(value, ty, "frombool").unwrap(),
(1, 1) => Ok(value),
(1, _) => Ok(builder.build_int_z_extend(value, ty, "frombool")?),
_ => bool_to_int_type(
builder,
builder
.build_int_compare(IntPredicate::NE, value, value.get_type().const_zero(), "tobool")
.unwrap(),
builder.build_int_compare(
IntPredicate::NE,
value,
value.get_type().const_zero(),
"tobool",
)?,
ty,
),
}
@@ -1060,21 +1058,12 @@ fn gen_in_range_check<'ctx>(
value: IntValue<'ctx>,
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign =
ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.i32.const_zero(), "").unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
) -> anyhow::Result<IntValue<'ctx>> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.i32.const_zero(), "")?;
let lo = ctx.builder.build_select(sign, value, stop, "").map(BasicValueEnum::into_int_value)?;
let hi = ctx.builder.build_select(sign, stop, value, "").map(BasicValueEnum::into_int_value)?;
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
Ok(ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")?)
}
/// Inserts an `alloca` instruction with allocation `size` given in bytes and the alignment of the
@@ -1087,13 +1076,13 @@ pub fn type_aligned_alloca<'ctx>(
align_ty: impl Into<BasicTypeEnum<'ctx>>,
size: IntValue<'ctx>,
name: Option<&'static str>,
) -> PointerValue<'ctx> {
) -> anyhow::Result<PointerValue<'ctx>> {
/// Round `val` up to its modulo `power_of_two`.
fn round_up<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntValue<'ctx>,
) -> IntValue<'ctx> {
) -> anyhow::Result<IntValue<'ctx>> {
debug_assert_eq!(
val.get_type().get_bit_width(),
power_of_two.get_type().get_bit_width(),
@@ -1105,20 +1094,18 @@ pub fn type_aligned_alloca<'ctx>(
let llvm_val_t = val.get_type();
let max_rem =
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "").unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")?;
Ok(ctx.builder.build_and(
ctx.builder.build_int_add(val, max_rem, "")?,
ctx.builder.build_not(max_rem, "")?,
"",
)?)
}
let llvm_usize = ctx.size_t;
let align_ty: BasicTypeEnum<'ctx> = align_ty.into();
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "")?;
debug_assert_eq!(
size.get_type().get_bit_width(),
@@ -1129,32 +1116,30 @@ pub fn type_aligned_alloca<'ctx>(
);
let alignment = align_ty.get_alignment();
let alignment = ctx.builder.build_int_truncate_or_bit_cast(alignment, llvm_usize, "").unwrap();
let alignment = ctx.builder.build_int_truncate_or_bit_cast(alignment, llvm_usize, "")?;
if ctx.registry.codegen_options.debug {
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None)?;
ctx.make_assert(
ctx.builder
.build_int_compare(
IntPredicate::EQ,
alignment_bitcount,
alignment_bitcount.get_type().const_int(1, false),
"",
)
.unwrap(),
ctx.builder.build_int_compare(
IntPredicate::EQ,
alignment_bitcount,
alignment_bitcount.get_type().const_int(1, false),
"",
)?,
"0:AssertionError",
"Expected power-of-two alignment for aligned_alloca, got {0}",
[Some(alignment), None, None],
ctx.current_loc,
);
)?;
}
let buffer_size = round_up(ctx, size, alignment);
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "").unwrap();
let buffer_size = round_up(ctx, size, alignment)?;
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "")?;
// Just to be absolutely sure, alloca in [i8 x alignment] slices
gen_dyn_array_var(ctx, align_ty, aligned_slices, name).value.0
Ok(gen_dyn_array_var(ctx, align_ty, aligned_slices, name)?.value.0)
}
/// Contains all global LLVM state that is attached to an LLVM [`Module`] and independent

View File

@@ -27,7 +27,7 @@ pub fn gen_ndarray_empty<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@@ -38,10 +38,10 @@ pub fn gen_ndarray_empty<'ctx>(
let llvm_dtype = context.get_llvm_type(dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg));
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg))?;
let ndarray =
NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_empty(context, shape, None);
NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_empty(context, shape, None)?;
Ok(ndarray.value)
}
@@ -51,7 +51,7 @@ pub fn gen_ndarray_zeros<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@@ -62,10 +62,10 @@ pub fn gen_ndarray_zeros<'ctx>(
let llvm_dtype = context.get_llvm_type(dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg));
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg))?;
let ndarray = NDArrayType::new(context, llvm_dtype, ndims)
.construct_numpy_zeros(context, dtype, shape, None);
.construct_numpy_zeros(context, dtype, shape, None)?;
Ok(ndarray.value)
}
@@ -75,7 +75,7 @@ pub fn gen_ndarray_ones<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@@ -86,10 +86,10 @@ pub fn gen_ndarray_ones<'ctx>(
let llvm_dtype = context.get_llvm_type(dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg));
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg))?;
let ndarray = NDArrayType::new(context, llvm_dtype, ndims)
.construct_numpy_ones(context, dtype, shape, None);
.construct_numpy_ones(context, dtype, shape, None)?;
Ok(ndarray.value)
}
@@ -99,7 +99,7 @@ pub fn gen_ndarray_full<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
@@ -112,14 +112,14 @@ pub fn gen_ndarray_full<'ctx>(
let llvm_dtype = context.get_llvm_type(dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg));
let shape = parse_numpy_int_sequence(context, (shape_ty, shape_arg))?;
let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full(
context,
shape,
fill_value_arg,
None,
);
)?;
Ok(ndarray.value)
}
@@ -128,7 +128,7 @@ pub fn gen_ndarray_array<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
@@ -141,7 +141,7 @@ pub fn gen_ndarray_array<'ctx>(
let copy_ty = fun.0.args[1].ty;
arg.1.clone().to_basic_value_enum(context, copy_ty)?
} else {
context.gen_symbol_val(fun.0.args[1].default_value.as_ref().unwrap(), fun.0.args[1].ty)
context.gen_symbol_val(fun.0.args[1].default_value.as_ref().unwrap(), fun.0.args[1].ty)?
};
// The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be
@@ -149,9 +149,9 @@ pub fn gen_ndarray_array<'ctx>(
let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
let copy = bool_to_i1(context, copy_arg.into_int_value());
let ndarray = NDArrayValue::construct_from(context, (obj_ty, obj_arg), copy, None)
.atleast_nd(context, ndims);
let copy = bool_to_i1(context, copy_arg.into_int_value())?;
let ndarray = NDArrayValue::construct_from(context, (obj_ty, obj_arg), copy, None)?
.atleast_nd(context, ndims)?;
Ok(ndarray.value)
}
@@ -162,7 +162,7 @@ pub fn gen_ndarray_eye<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
@@ -184,7 +184,7 @@ pub fn gen_ndarray_eye<'ctx>(
{
arg.1.clone().to_basic_value_enum(context, offset_ty)
} else {
Ok(context.gen_symbol_val(fun.0.args[2].default_value.as_ref().unwrap(), offset_ty))
context.gen_symbol_val(fun.0.args[2].default_value.as_ref().unwrap(), offset_ty)
}?;
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
@@ -192,21 +192,24 @@ pub fn gen_ndarray_eye<'ctx>(
let llvm_usize = context.size_t;
let llvm_dtype = context.get_llvm_type(dtype);
let nrows = context
.builder
.build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "")
.unwrap();
let ncols = context
.builder
.build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "")
.unwrap();
let offset = context
.builder
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
.unwrap();
let nrows = context.builder.build_int_s_extend_or_bit_cast(
nrows_arg.into_int_value(),
llvm_usize,
"",
)?;
let ncols = context.builder.build_int_s_extend_or_bit_cast(
ncols_arg.into_int_value(),
llvm_usize,
"",
)?;
let offset = context.builder.build_int_s_extend_or_bit_cast(
offset_arg.into_int_value(),
llvm_usize,
"",
)?;
let ndarray = NDArrayType::new(context, llvm_dtype, 2)
.construct_numpy_eye(context, dtype, nrows, ncols, offset, None);
.construct_numpy_eye(context, dtype, nrows, ncols, offset, None)?;
Ok(ndarray.value)
}
@@ -216,7 +219,7 @@ pub fn gen_ndarray_identity<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@@ -228,12 +231,10 @@ pub fn gen_ndarray_identity<'ctx>(
let llvm_usize = context.size_t;
let llvm_dtype = context.get_llvm_type(dtype);
let n = context
.builder
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
.unwrap();
let ndarray =
NDArrayType::new(context, llvm_dtype, 2).construct_numpy_identity(context, dtype, n, None);
let n =
context.builder.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")?;
let ndarray = NDArrayType::new(context, llvm_dtype, 2)
.construct_numpy_identity(context, dtype, n, None)?;
Ok(ndarray.value)
}
@@ -243,7 +244,7 @@ pub fn gen_ndarray_copy<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<PointerValue<'ctx>, String> {
) -> anyhow::Result<PointerValue<'ctx>> {
assert!(obj.is_some());
assert!(args.is_empty());
@@ -252,7 +253,7 @@ pub fn gen_ndarray_copy<'ctx>(
let this = NDArrayType::from_unifier_type(context, this_ty)
.map_value(this_arg.into_pointer_value(), None);
let ndarray = this.make_copy(context);
let ndarray = this.make_copy(context)?;
Ok(ndarray.value)
}
@@ -262,7 +263,7 @@ pub fn gen_ndarray_fill<'ctx>(
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
) -> Result<(), String> {
) -> anyhow::Result<()> {
assert!(obj.is_some());
assert_eq!(args.len(), 1);
@@ -273,7 +274,7 @@ pub fn gen_ndarray_fill<'ctx>(
let this = NDArrayType::from_unifier_type(context, this_ty)
.map_value(this_arg.into_pointer_value(), None);
this.fill(context, value_arg);
this.fill(context, value_arg)?;
Ok(())
}
@@ -287,7 +288,7 @@ pub fn ndarray_dot<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
) -> anyhow::Result<BasicValueEnum<'ctx>> {
const FN_NAME: &str = "ndarray_dot";
match (x1, x2) {
@@ -301,22 +302,21 @@ pub fn ndarray_dot<'ctx>(
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
// Check shapes.
let a_size = a.size(ctx);
let b_size = b.size(ctx);
let same_shape =
ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap();
let a_size = a.size(ctx)?;
let b_size = b.size(ctx)?;
let same_shape = ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "")?;
ctx.make_assert(
same_shape,
"0:ValueError",
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
[Some(a_size), Some(b_size), None],
ctx.current_loc,
);
)?;
let dtype_llvm = ctx.get_llvm_type(common_dtype);
let result = gen_var(ctx, dtype_llvm, Some("np_dot_result"));
typed_store(ctx.builder, result, dtype_llvm.const_zero());
let result = gen_var(ctx, dtype_llvm, Some("np_dot_result"))?;
typed_store(ctx.builder, result, dtype_llvm.const_zero())?;
// Do dot product.
gen_for_callback(
@@ -324,32 +324,32 @@ pub fn ndarray_dot<'ctx>(
ctx,
Some("np_dot"),
|(), ctx| {
let a_iter = NDIterValue::new(ctx, a);
let b_iter = NDIterValue::new(ctx, b);
let a_iter = NDIterValue::new(ctx, a)?;
let b_iter = NDIterValue::new(ctx, b)?;
Ok((a_iter, b_iter))
},
|(), ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_element(ctx))
a_iter.has_element(ctx)
},
|(), ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(ctx);
let b_scalar = b_iter.get_scalar(ctx);
let a_scalar = a_iter.get_scalar(ctx)?;
let b_scalar = b_iter.get_scalar(ctx)?;
let old_result = ctx.builder.build_load(result, "").unwrap();
let old_result = ctx.builder.build_load(result, "")?;
let new_result: BasicValueEnum<'ctx> = match old_result {
BasicValueEnum::IntValue(old_result) => {
let a_scalar = a_scalar.into_int_value();
let b_scalar = b_scalar.into_int_value();
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "")?;
ctx.builder.build_int_add(old_result, x, "")?.into()
}
BasicValueEnum::FloatValue(old_result) => {
let a_scalar = a_scalar.into_float_value();
let b_scalar = b_scalar.into_float_value();
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "")?;
ctx.builder.build_float_add(old_result, x, "")?.into()
}
_ => {
@@ -357,27 +357,26 @@ pub fn ndarray_dot<'ctx>(
}
};
typed_store(ctx.builder, result, new_result);
typed_store(ctx.builder, result, new_result)?;
Ok(())
},
|(), ctx, (a_iter, b_iter)| {
a_iter.next(ctx);
b_iter.next(ctx);
a_iter.next(ctx)?;
b_iter.next(ctx)?;
Ok(())
},
|(), _| Ok(()),
)
.unwrap();
)?;
Ok(ctx.builder.build_load(result, "").unwrap())
Ok(ctx.builder.build_load(result, "")?)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
Ok(ctx.builder.build_int_mul(e1, e2, "")?.as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
Ok(ctx.builder.build_float_mul(e1, e2, "")?.as_basic_value_enum())
}
_ => codegen_unreachable!(

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ use std::{
sync::Arc,
};
use anyhow::anyhow;
use function_name::named;
use indexmap::IndexMap;
use indoc::indoc;
@@ -47,7 +48,7 @@ impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
) -> anyhow::Result<Option<crate::symbol_resolver::SymbolValue>> {
unimplemented!()
}
@@ -57,24 +58,24 @@ impl SymbolResolver for Resolver {
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
) -> anyhow::Result<Type> {
self.id_to_type.get(&str).copied().ok_or_else(|| anyhow!("cannot find symbol `{str}`"))
}
fn get_symbol_value<'ctx>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> {
) -> anyhow::Result<Option<ValueEnum<'ctx>>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, Vec<anyhow::Error>> {
self.id_to_def
.read()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
.ok_or_else(|| vec![anyhow!("cannot find symbol `{id}`")])
}
fn get_string_id(&self, _: &str) -> i32 {

View File

@@ -1,3 +1,4 @@
use anyhow::anyhow;
use inkwell::{
AddressSpace, IntPredicate,
types::{BasicType, BasicTypeEnum},
@@ -20,7 +21,7 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
) -> anyhow::Result<PointerValue<'ctx>>;
/// Returns the pointer to the data at the `idx`-th index. Raise an error
/// if the index is out of bounds.
@@ -29,7 +30,7 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
) -> anyhow::Result<PointerValue<'ctx>>;
/// Loads the value at the `idx`-th index without bounds checking.
fn get_unchecked<V: TryFrom<BasicValueEnum<'ctx>, Error: core::fmt::Debug>>(
@@ -37,9 +38,11 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &Index,
name: Option<&str>,
) -> V {
let ptr = self.ptr_offset_unchecked(ctx, idx, name);
typed_load(ctx.builder, ptr, self.item_type(), name.unwrap_or_default()).try_into().unwrap()
) -> anyhow::Result<V> {
let ptr = self.ptr_offset_unchecked(ctx, idx, name)?;
typed_load(ctx.builder, ptr, self.item_type(), name.unwrap_or_default())?
.try_into()
.map_err(|e| anyhow!("{e:?}"))
}
/// Loads the value at the `idx`-th index with bounds checking.
@@ -48,9 +51,11 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &Index,
name: Option<&str>,
) -> V {
let ptr = self.ptr_offset(ctx, idx, name);
typed_load(ctx.builder, ptr, self.item_type(), name.unwrap_or_default()).try_into().unwrap()
) -> anyhow::Result<V> {
let ptr = self.ptr_offset(ctx, idx, name)?;
typed_load(ctx.builder, ptr, self.item_type(), name.unwrap_or_default())?
.try_into()
.map_err(|e| anyhow!("{e:?}"))
}
/// Stores the `value` at the `idx`-th index without bounds checking.
@@ -60,9 +65,10 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
idx: &Index,
value: V,
name: Option<&str>,
) {
let ptr = self.ptr_offset_unchecked(ctx, idx, name);
typed_store(ctx.builder, ptr, value.as_basic_value_enum());
) -> anyhow::Result<()> {
let ptr = self.ptr_offset_unchecked(ctx, idx, name)?;
typed_store(ctx.builder, ptr, value.as_basic_value_enum())?;
Ok(())
}
/// Stores the `value` at the `idx`-th index with bounds checking.
@@ -72,9 +78,10 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>> {
idx: &Index,
value: V,
name: Option<&str>,
) {
let ptr = self.ptr_offset(ctx, idx, name);
typed_store(ctx.builder, ptr, value.as_basic_value_enum());
) -> anyhow::Result<()> {
let ptr = self.ptr_offset(ctx, idx, name)?;
typed_store(ctx.builder, ptr, value.as_basic_value_enum())?;
Ok(())
}
}
@@ -101,12 +108,17 @@ impl<'ctx> ArraySliceValue<'ctx> {
}
/// Copies data from the source pointer into this array slice.
pub fn memcpy_from(&self, ctx: &mut CodeGenContext<'ctx, '_>, src: PointerValue<'ctx>) {
pub fn memcpy_from(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
src: PointerValue<'ctx>,
) -> anyhow::Result<()> {
let size = ctx.sizeof(self.ty.item_ty);
let size = ctx.size_t.const_int(size, false);
let align = ctx.target.get_target_data().get_abi_alignment(&self.ty.item_ty);
let bytes = ctx.builder.build_int_mul(self.value.1, size, "").unwrap();
ctx.builder.build_memcpy(self.value.0, align, src, align, bytes).unwrap();
let bytes = ctx.builder.build_int_mul(self.value.1, size, "")?;
ctx.builder.build_memcpy(self.value.0, align, src, align, bytes)?;
Ok(())
}
}
@@ -120,20 +132,17 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
) -> anyhow::Result<PointerValue<'ctx>> {
let var_name = name.or(self.name).map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
let ptr = ctx
.builder
.build_pointer_cast(
self.value.0,
self.ty.item_ty.ptr_type(AddressSpace::default()),
"",
)
.unwrap();
let r = ctx.builder.build_in_bounds_gep(ptr, &[*idx], var_name.as_str()).unwrap();
ctx.builder.build_pointer_cast(r, ctx.ptr, name.unwrap_or("")).unwrap()
let ptr = ctx.builder.build_pointer_cast(
self.value.0,
self.ty.item_ty.ptr_type(AddressSpace::default()),
"",
)?;
let r = ctx.builder.build_in_bounds_gep(ptr, &[*idx], var_name.as_str())?;
Ok(ctx.builder.build_pointer_cast(r, ctx.ptr, name.unwrap_or(""))?)
}
}
@@ -142,18 +151,18 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
) -> anyhow::Result<PointerValue<'ctx>> {
debug_assert_eq!(idx.get_type(), ctx.size_t);
let size = self.value.1;
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "")?;
ctx.make_assert(
in_range,
"0:IndexError",
"index {0} is out of bounds for size {1}",
[Some(*idx), Some(size), None],
ctx.current_loc,
);
)?;
self.ptr_offset_unchecked(ctx, idx, name)
}

View File

@@ -59,22 +59,19 @@ impl<'ctx> ListType<'ctx> {
}
/// Allocates a new list with the given length.
#[must_use]
pub fn construct(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: Option<&'static str>,
) -> ListValue<'ctx> {
let list = self.alloca(ctx, name);
) -> anyhow::Result<ListValue<'ctx>> {
let list = self.alloca(ctx, name)?;
let len = ctx.builder.build_int_z_extend(len, ctx.size_t, "").unwrap();
list.store(ctx, field!(len), len);
let len = ctx.builder.build_int_z_extend(len, ctx.size_t, "")?;
list.store(ctx, field!(len), len)?;
let len_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, len, ctx.size_t.const_zero(), "")
.unwrap();
let len_eqz =
ctx.builder.build_int_compare(IntPredicate::EQ, len, ctx.size_t.const_zero(), "")?;
let null = ctx.ptr.const_null();
let data = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(self.item_ty) {
@@ -86,17 +83,17 @@ impl<'ctx> ListType<'ctx> {
"Cannot allocate a non-empty list with unknown element type",
[None, None, None],
ctx.current_loc,
);
)?;
}
null
} else {
let ty = ctx.get_llvm_type(self.item_ty);
let array = gen_dyn_array_var(ctx, ty, len, None).value.0;
ctx.builder.build_select(len_eqz, null, array, "").unwrap().into_pointer_value()
let array = gen_dyn_array_var(ctx, ty, len, None)?.value.0;
ctx.builder.build_select(len_eqz, null, array, "")?.into_pointer_value()
};
list.store(ctx, field!(items), data);
list
list.store(ctx, field!(items), data)?;
Ok(list)
}
}
@@ -104,7 +101,10 @@ pub type ListValue<'ctx> = Value<'ctx, ListType<'ctx>>;
impl<'ctx> ListValue<'ctx> {
/// Returns the data of this list as an [`ArraySliceValue`].
pub fn data(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> ArraySliceValue<'ctx> {
pub fn data(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<ArraySliceValue<'ctx>> {
let item_ty = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(self.ty.item_ty)
{
// Use a placeholder type.
@@ -113,12 +113,12 @@ impl<'ctx> ListValue<'ctx> {
ctx.get_llvm_type(self.ty.item_ty)
};
ArraySliceValue::new(
Ok(ArraySliceValue::new(
item_ty,
self.load(ctx, field!(items)),
self.load(ctx, field!(len)),
self.load(ctx, field!(items))?,
self.load(ctx, field!(len))?,
self.name,
)
))
}
/// Creates an empty list with the given item type.
@@ -129,7 +129,7 @@ impl<'ctx> ListValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
item_ty: Type,
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
let list_ty = ListType::new(ctx, item_ty);
list_ty.construct(ctx, ctx.size_t.const_zero(), name)
}

View File

@@ -60,7 +60,7 @@ use impl_proxy_type;
/// ctx: &mut CodeGenContext<'ctx, '_>,
/// list: &ListValue<'ctx>)
/// -> IntValue<'ctx> {
/// list.load(ctx, field!(len))
/// list.load(ctx, field!(len)).unwrap()
/// }
/// ```
#[doc(hidden)]
@@ -110,14 +110,14 @@ pub trait ProxyTypeExt {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'static str>,
) -> Value<'ctx, Self>
) -> anyhow::Result<Value<'ctx, Self>>
where
Self: RefType<'ctx> + Copy,
{
let alloca = self.alloca_ty(ctx);
let ptr = gen_var(ctx, alloca, name);
let ptr = ctx.builder.build_pointer_cast(ptr, ctx.ptr, "ptr_cast").unwrap();
Value { ty: *self, value: ptr, name }
let ptr = gen_var(ctx, alloca, name)?;
let ptr = ctx.builder.build_pointer_cast(ptr, ctx.ptr, "ptr_cast")?;
Ok(Value { ty: *self, value: ptr, name })
}
/// Maps an existing value of the underlying LLVM type to a typed value.
@@ -163,7 +163,7 @@ impl<'ctx, T: ProxyTypeMarker<'ctx, Value = PointerValue<'ctx>>> Value<'ctx, T>
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
field: impl FnOnce(&T) -> StructField<'ctx, B>,
) -> B
) -> anyhow::Result<B>
where
T: RefType<'ctx>,
B: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
@@ -180,11 +180,12 @@ impl<'ctx, T: ProxyTypeMarker<'ctx, Value = PointerValue<'ctx>>> Value<'ctx, T>
ctx: &mut CodeGenContext<'ctx, '_>,
field: impl FnOnce(&T) -> StructField<'ctx, B>,
value: B,
) where
) -> anyhow::Result<()>
where
T: RefType<'ctx>,
B: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error: std::fmt::Debug>,
{
let struct_ty = self.ty.alloca_ty(ctx);
field(&self.ty).store(ctx, struct_ty, self.value, value, self.name);
field(&self.ty).store(ctx, struct_ty, self.value, value, self.name)
}
}

View File

@@ -38,30 +38,30 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
(list_ty, list): (Type, ListValue<'ctx>),
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list_ty);
// Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = ctx.size_t.const_int(ndims_int, false);
let shape = gen_array_var(ctx, ctx.size_t, ndims_int, None);
let shape = gen_array_var(ctx, ctx.size_t, ndims_int, None)?;
let fn_name = get_usize_dependent_function_name(
ctx,
"__nac3_ndarray_array_set_and_validate_list_shape",
);
call_extern!(ctx: void _ = fn_name(list.value, ndims, shape.value.0));
call_extern!(ctx: void _ = fn_name(list.value, ndims, shape.value.0))?;
let ndarray = NDArrayType::new(ctx, dtype, ndims_int).construct(ctx, name);
ndarray.shape(ctx).memcpy_from(ctx, shape.value.0);
ndarray.create_data(ctx);
let ndarray = NDArrayType::new(ctx, dtype, ndims_int).construct(ctx, name)?;
ndarray.shape(ctx)?.memcpy_from(ctx, shape.value.0)?;
ndarray.create_data(ctx)?;
// Copy all contents from the list.
let fn_name =
get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array");
call_extern!(ctx: void _ = fn_name(list.value, ndarray.value));
call_extern!(ctx: void _ = fn_name(list.value, ndarray.value))?;
ndarray
Ok(ndarray)
}
/// Implementation of `np_array(<list>, copy=None)`
@@ -69,7 +69,7 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
(list_ty, list): (Type, ListValue<'ctx>),
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
// np_array without copying is only possible `list` is not nested.
//
// If `list` is `list[T]`, we can create an ndarray with `data` set
@@ -82,16 +82,16 @@ impl<'ctx> NDArrayValue<'ctx> {
// `list` is not nested
assert_eq!(ndims, 1);
let ndarray = NDArrayType::new(ctx, dtype, 1).construct(ctx, name);
let ndarray = NDArrayType::new(ctx, dtype, 1).construct(ctx, name)?;
let (data, len) = list.data(ctx).value;
ndarray.store(ctx, field!(data), data);
let (data, len) = list.data(ctx)?.value;
ndarray.store(ctx, field!(data), data)?;
// ndarray->shape[0] = list->len;
ndarray.shape(ctx).set_unchecked(ctx, &ctx.size_t.const_zero(), len, None);
ndarray.shape(ctx)?.set_unchecked(ctx, &ctx.size_t.const_zero(), len, None)?;
// Set strides, the `data` is contiguous
ndarray.set_strides_contiguous(ctx);
ndarray.set_strides_contiguous(ctx)?;
ndarray
Ok(ndarray)
} else {
// `list` is nested, copy
Self::from_list_must_copy(ctx, (list_ty, list), name)
@@ -104,7 +104,7 @@ impl<'ctx> NDArrayValue<'ctx> {
(list_ty, list): (Type, ListValue<'ctx>),
copy: IntValue<'ctx>,
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
assert_eq!(copy.get_type(), ctx.i1);
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list_ty);
@@ -114,18 +114,17 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx,
|(), _ctx| Ok(copy),
|(), ctx| {
let ndarray = Self::from_list_must_copy(ctx, (list_ty, list), name);
let ndarray = Self::from_list_must_copy(ctx, (list_ty, list), name)?;
Ok(Some(ndarray.value))
},
|(), ctx| {
let ndarray = Self::from_list_maybe_copy(ctx, (list_ty, list), name);
let ndarray = Self::from_list_maybe_copy(ctx, (list_ty, list), name)?;
Ok(Some(ndarray.value))
},
)
.unwrap()
)?
.unwrap();
NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None)
Ok(NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None))
}
/// Implementation of `np_array(<ndarray>, copy=copy)`.
@@ -134,7 +133,7 @@ impl<'ctx> NDArrayValue<'ctx> {
ndarray: Self,
copy: IntValue<'ctx>,
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
assert_eq!(copy.get_type(), ctx.i1);
let ndarray_val = gen_if_else_expr_callback(
@@ -142,18 +141,17 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx,
|(), _ctx| Ok(copy),
|(), ctx| {
let ndarray = ndarray.make_copy(ctx); // Force copy
let ndarray = ndarray.make_copy(ctx)?; // Force copy
Ok(Some(ndarray.value))
},
|(), _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.value))
},
)
.unwrap()
)?
.unwrap();
ndarray.ty.map_value(ndarray_val, name)
Ok(ndarray.ty.map_value(ndarray_val, name))
}
/// Create a new ndarray like
@@ -167,7 +165,7 @@ impl<'ctx> NDArrayValue<'ctx> {
(object_ty, object): (Type, BasicValueEnum<'ctx>),
copy: IntValue<'ctx>,
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
match &*ctx.unifier.get_ty_immutable(object_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>

View File

@@ -2,7 +2,6 @@ use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
};
use itertools::Itertools as _;
use nac3core_derive::{ProxyType, StructFields};
use crate::codegen::{
@@ -52,23 +51,22 @@ impl<'ctx> NDArrayValue<'ctx> {
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
/// The caller has to figure this out for this function.
/// * `target_shape` - An array pointer pointing to the target shape.
#[must_use]
pub fn broadcast_to(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
target_ndims: u64,
target_shape: ArraySliceValue<'ctx>,
) -> Self {
) -> anyhow::Result<Self> {
assert!(self.ty.ndims <= target_ndims);
assert_eq!(target_shape.ty.item_ty, ctx.size_t.into());
let broadcast_ndarray =
NDArrayType::new(ctx, self.ty.dtype, target_ndims).construct(ctx, None);
broadcast_ndarray.shape(ctx).memcpy_from(ctx, target_shape.value.0);
NDArrayType::new(ctx, self.ty.dtype, target_ndims).construct(ctx, None)?;
broadcast_ndarray.shape(ctx)?.memcpy_from(ctx, target_shape.value.0)?;
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to");
call_extern!(ctx: void _ = name(self.value, broadcast_ndarray.value));
broadcast_ndarray
call_extern!(ctx: void _ = name(self.value, broadcast_ndarray.value))?;
Ok(broadcast_ndarray)
}
}
@@ -97,34 +95,36 @@ pub struct BroadcastAllResult<'ctx> {
pub fn broadcast<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: &[NDArrayValue<'ctx>],
) -> BroadcastAllResult<'ctx> {
) -> anyhow::Result<BroadcastAllResult<'ctx>> {
let shape_entry_ty = ShapeEntryType::new(ctx);
let shape_entries = ctx.size_t.const_int(ndarrays.len() as _, false);
let arr = gen_array_var(ctx, shape_entry_ty.inner.llvm_ty, ndarrays.len() as _, None);
let arr = gen_array_var(ctx, shape_entry_ty.inner.llvm_ty, ndarrays.len() as _, None)?;
// Store shapes into memory.
for (i, ndarray) in ndarrays.iter().enumerate() {
let idx = ctx.size_t.const_int(i as _, false);
let pshape_entry = arr.ptr_offset_unchecked(ctx, &idx, None);
let pshape_entry = arr.ptr_offset_unchecked(ctx, &idx, None)?;
let shape_entry = shape_entry_ty.map_value(pshape_entry, None);
let ndims = ndarray.ty.ndims_val(ctx);
let shape = ndarray.shape(ctx).value.0;
shape_entry.store(ctx, field!(ndims), ndims);
shape_entry.store(ctx, field!(shape), shape);
let shape = ndarray.shape(ctx)?.value.0;
shape_entry.store(ctx, field!(ndims), ndims)?;
shape_entry.store(ctx, field!(shape), shape)?;
}
let ndims = ndarrays.iter().map(|ndarray| ndarray.ty.ndims).max().unwrap();
let new_shape_ptr = gen_array_var(ctx, ctx.size_t, ndims, None).value.0;
let new_shape_ptr = gen_array_var(ctx, ctx.size_t, ndims, None)?.value.0;
let ndims_v = ctx.size_t.const_int(ndims, false);
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes");
call_extern!(ctx: void _ = name(shape_entries, arr.value.0, ndims_v, new_shape_ptr));
call_extern!(ctx: void _ = name(shape_entries, arr.value.0, ndims_v, new_shape_ptr))?;
// Now this new shape is initialized.
let new_shape = ArraySliceValue::new(ctx.size_t.into(), new_shape_ptr, ndims_v, None);
let new_ndarrays =
ndarrays.iter().map(|ndarray| ndarray.broadcast_to(ctx, ndims, new_shape)).collect_vec();
BroadcastAllResult { ndims, shape: new_shape, ndarrays: new_ndarrays }
let new_ndarrays = ndarrays
.iter()
.map(|ndarray| ndarray.broadcast_to(ctx, ndims, new_shape))
.collect::<anyhow::Result<Vec<_>>>()?;
Ok(BroadcastAllResult { ndims, shape: new_shape, ndarrays: new_ndarrays })
}
/// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping`
@@ -141,19 +141,19 @@ pub fn broadcast_starmap<'ctx, 'a, MappingFn>(
ndarrays: &[NDArrayValue<'ctx>],
out: NDArrayOut<'ctx>,
mapping: MappingFn,
) -> Result<NDArrayValue<'ctx>, String>
) -> anyhow::Result<NDArrayValue<'ctx>>
where
MappingFn: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
&[BasicValueEnum<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
) -> anyhow::Result<BasicValueEnum<'ctx>>,
{
// Broadcast inputs
let broadcast_result = broadcast(ctx, ndarrays);
let out_ndarray = out.resolve(ctx, broadcast_result.ndims, broadcast_result.shape);
let broadcast_result = broadcast(ctx, ndarrays)?;
let out_ndarray = out.resolve(ctx, broadcast_result.ndims, broadcast_result.shape)?;
// Map element-wise and store results into `mapped_ndarray`.
let nditer = NDIterValue::new(ctx, out_ndarray);
let nditer = NDIterValue::new(ctx, out_ndarray)?;
gen_for_callback(
&mut (),
ctx,
@@ -164,31 +164,34 @@ where
.ndarrays
.iter()
.map(|ndarray| NDIterValue::new(ctx, *ndarray))
.collect_vec();
.collect::<anyhow::Result<Vec<_>>>()?;
Ok((nditer, other_nditers))
},
|(), ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_element()`.
// `in_nditers`' `has_element()`s should return the same value.
Ok(out_nditer.has_element(ctx))
out_nditer.has_element(ctx)
},
|(), ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
// and write to `out_ndarray`.
let in_scalars = in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec();
let in_scalars = in_nditers
.iter()
.map(|nditer| nditer.get_scalar(ctx))
.collect::<anyhow::Result<Vec<_>>>()?;
let result = mapping(ctx, &in_scalars)?;
let p = out_nditer.curr_ptr(ctx);
typed_store(ctx.builder, p, result);
let p = out_nditer.curr_ptr(ctx)?;
typed_store(ctx.builder, p, result)?;
Ok(())
},
|(), ctx, (out_nditer, in_nditers)| {
// Advance all iterators
out_nditer.next(ctx);
out_nditer.next(ctx)?;
for nditer in &in_nditers {
nditer.next(ctx);
nditer.next(ctx)?;
}
Ok(())
},
@@ -228,12 +231,12 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
inputs: &[Self],
ret_dtype: BasicTypeEnum<'ctx>,
mapping: MappingFn,
) -> Result<Self, String>
) -> anyhow::Result<Self>
where
MappingFn: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
&[BasicValueEnum<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
) -> anyhow::Result<BasicValueEnum<'ctx>>,
{
// Check if all inputs are Scalars
let all_scalars: Option<Vec<_>> = inputs
@@ -249,7 +252,10 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
Ok(ScalarOrNDArray::Scalar(value))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.to_ndarray(ctx)).collect_vec();
let inputs = inputs
.iter()
.map(|input| input.to_ndarray(ctx))
.collect::<anyhow::Result<Vec<_>>>()?;
let ret = NDArrayOut::NewNDArray { dtype: ret_dtype };
let ndarray = broadcast_starmap(ctx, &inputs, ret, mapping)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
@@ -264,12 +270,12 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, 'a>,
out: NDArrayOut<'ctx>,
mapping: Mapping,
) -> Result<Self, String>
) -> anyhow::Result<Self>
where
Mapping: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> anyhow::Result<BasicValueEnum<'ctx>>,
{
broadcast_starmap(ctx, &[*self], out, |ctx, scalars| mapping(ctx, scalars[0]))
}
@@ -288,12 +294,12 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: BasicTypeEnum<'ctx>,
mapping: Mapping,
) -> Result<Self, String>
) -> anyhow::Result<Self>
where
Mapping: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> anyhow::Result<BasicValueEnum<'ctx>>,
{
ScalarOrNDArray::broadcasting_starmap(ctx, &[*self], ret_dtype, |ctx, scalars| {
mapping(ctx, scalars[0])

View File

@@ -40,44 +40,43 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn make_contiguous_ndarray(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ContiguousNDArrayValue<'ctx> {
) -> anyhow::Result<ContiguousNDArrayValue<'ctx>> {
let result = ContiguousNDArrayType {
inner: BuiltinStruct::new(ctx, "contiguous_ndarray"),
dtype: self.ty.dtype,
ndims: self.ty.ndims,
};
let result = result.alloca(ctx, self.name);
let result = result.alloca(ctx, self.name)?;
// Set ndims and shape.
let ndims = self.ty.ndims_val(ctx);
result.store(ctx, field!(ndims), ndims);
result.store(ctx, field!(ndims), ndims)?;
let shape = self.load(ctx, field!(shape));
result.store(ctx, field!(shape), shape);
let shape = self.load(ctx, field!(shape))?;
result.store(ctx, field!(shape), shape)?;
gen_if_callback(
&mut (),
ctx,
|(), ctx| Ok(self.is_c_contiguous(ctx)),
|(), ctx| self.is_c_contiguous(ctx),
|(), ctx| {
// This ndarray is contiguous.
let data = self.load(ctx, field!(data));
result.store(ctx, field!(data), data);
let data = self.load(ctx, field!(data))?;
result.store(ctx, field!(data), data)?;
Ok(())
},
|(), ctx| {
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
// ndarray with contiguous `data`.
let copied_ndarray = self.make_copy(ctx);
let data = copied_ndarray.load(ctx, field!(data));
result.store(ctx, field!(data), data);
let copied_ndarray = self.make_copy(ctx)?;
let data = copied_ndarray.load(ctx, field!(data))?;
result.store(ctx, field!(data), data)?;
Ok(())
},
)
.unwrap();
)?;
result
Ok(result)
}
/// Create an [`NDArrayValue`] from a [`ContiguousNDArrayValue`].
@@ -92,21 +91,21 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
carray: ContiguousNDArrayValue<'ctx>,
ndims: u64,
) -> Self {
) -> anyhow::Result<Self> {
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
// Allocate the resulting ndarray.
let ndarray = NDArrayType::new(ctx, carray.ty.dtype, ndims).construct(ctx, carray.name);
let ndarray = NDArrayType::new(ctx, carray.ty.dtype, ndims).construct(ctx, carray.name)?;
// Copy shape and update strides
let shape = carray.load(ctx, field!(shape));
ndarray.shape(ctx).memcpy_from(ctx, shape);
ndarray.set_strides_contiguous(ctx);
let shape = carray.load(ctx, field!(shape))?;
ndarray.shape(ctx)?.memcpy_from(ctx, shape)?;
ndarray.set_strides_contiguous(ctx)?;
// Share data
let data = carray.load(ctx, field!(data));
ndarray.store(ctx, field!(data), data);
let data = carray.load(ctx, field!(data))?;
ndarray.store(ctx, field!(data), data)?;
ndarray
Ok(ndarray)
}
}

View File

@@ -1,3 +1,4 @@
use anyhow::bail;
use inkwell::{
IntPredicate,
values::{BasicValueEnum, IntValue},
@@ -21,54 +22,58 @@ use crate::{
fn ndarray_zero_value<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.i32.const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.i64.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.f64.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.i1.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string("").into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
) -> anyhow::Result<BasicValueEnum<'ctx>> {
Ok(
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.i32.const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.i64.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.f64.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.i1.const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string("")?.into()
} else {
bail!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
},
)
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.i32.const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.i64.const_int(1, is_signed).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.f64.const_float(1.0).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.i1.const_int(1, false).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string("1").into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
) -> anyhow::Result<BasicValueEnum<'ctx>> {
Ok(
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.i32.const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.i64.const_int(1, is_signed).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.f64.const_float(1.0).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.i1.const_int(1, false).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string("1")?.into()
} else {
bail!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
},
)
}
impl<'ctx> NDArrayType<'ctx> {
@@ -79,18 +84,18 @@ impl<'ctx> NDArrayType<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
shape: ArraySliceValue<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
let ndarray = self.construct(ctx, name);
) -> anyhow::Result<NDArrayValue<'ctx>> {
let ndarray = self.construct(ctx, name)?;
// Validate `shape`
let (shape_ptr, shape_len) = shape.value;
let name =
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
call_extern!(ctx: (ctx.size_t) _ = name(shape_len, shape_ptr));
call_extern!(ctx: (ctx.size_t) _ = name(shape_len, shape_ptr))?;
ndarray.shape(ctx).memcpy_from(ctx, shape_ptr);
ndarray.create_data(ctx);
ndarray
ndarray.shape(ctx)?.memcpy_from(ctx, shape_ptr)?;
ndarray.create_data(ctx)?;
Ok(ndarray)
}
/// Create an ndarray like
@@ -101,10 +106,10 @@ impl<'ctx> NDArrayType<'ctx> {
shape: ArraySliceValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
let ndarray = self.construct_numpy_empty(ctx, shape, name);
ndarray.fill(ctx, fill_value);
ndarray
) -> anyhow::Result<NDArrayValue<'ctx>> {
let ndarray = self.construct_numpy_empty(ctx, shape, name)?;
ndarray.fill(ctx, fill_value)?;
Ok(ndarray)
}
fn assert_compatible_dtype(&self, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type) {
@@ -125,9 +130,9 @@ impl<'ctx> NDArrayType<'ctx> {
dtype: Type,
shape: ArraySliceValue<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
) -> anyhow::Result<NDArrayValue<'ctx>> {
self.assert_compatible_dtype(ctx, dtype);
let fill_value = ndarray_zero_value(ctx, dtype);
let fill_value = ndarray_zero_value(ctx, dtype)?;
self.construct_numpy_full(ctx, shape, fill_value, name)
}
@@ -139,9 +144,9 @@ impl<'ctx> NDArrayType<'ctx> {
dtype: Type,
shape: ArraySliceValue<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
) -> anyhow::Result<NDArrayValue<'ctx>> {
self.assert_compatible_dtype(ctx, dtype);
let fill_value = ndarray_one_value(ctx, dtype);
let fill_value = ndarray_one_value(ctx, dtype)?;
self.construct_numpy_full(ctx, shape, fill_value, name)
}
@@ -156,41 +161,36 @@ impl<'ctx> NDArrayType<'ctx> {
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
) -> anyhow::Result<NDArrayValue<'ctx>> {
self.assert_compatible_dtype(ctx, dtype);
assert_eq!(nrows.get_type(), ctx.size_t);
assert_eq!(ncols.get_type(), ctx.size_t);
assert_eq!(offset.get_type(), ctx.size_t);
let ndzero = ndarray_zero_value(ctx, dtype);
let ndone = ndarray_one_value(ctx, dtype);
let ndzero = ndarray_zero_value(ctx, dtype)?;
let ndone = ndarray_one_value(ctx, dtype)?;
let ndarray = self.with_shape(ctx, &[nrows, ncols], name);
let ndarray = self.with_shape(ctx, &[nrows, ncols], name)?;
ndarray
.foreach(ctx, |ctx, _, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
ndarray.foreach(ctx, |ctx, _, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
let indices = nditer.indices(ctx);
let row_i = indices.get_unchecked(ctx, &ctx.size_t.const_zero(), None);
let col_i = indices.get_unchecked(ctx, &ctx.size_t.const_int(1, false), None);
let indices = nditer.indices(ctx)?;
let row_i = indices.get_unchecked(ctx, &ctx.size_t.const_zero(), None)?;
let col_i = indices.get_unchecked(ctx, &ctx.size_t.const_int(1, false), None)?;
let with_offset = ctx.builder.build_int_add(row_i, offset, "").unwrap();
let be_one = ctx
.builder
.build_int_compare(IntPredicate::EQ, with_offset, col_i, "")
.unwrap();
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap();
let with_offset = ctx.builder.build_int_add(row_i, offset, "")?;
let be_one = ctx.builder.build_int_compare(IntPredicate::EQ, with_offset, col_i, "")?;
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value")?;
let p = nditer.curr_ptr(ctx);
typed_store(ctx.builder, p, value);
let p = nditer.curr_ptr(ctx)?;
typed_store(ctx.builder, p, value)?;
Ok(())
})
.unwrap();
Ok(())
})?;
ndarray
Ok(ndarray)
}
/// Create an ndarray like
@@ -201,7 +201,7 @@ impl<'ctx> NDArrayType<'ctx> {
dtype: Type,
size: IntValue<'ctx>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
) -> anyhow::Result<NDArrayValue<'ctx>> {
let offset = ctx.size_t.const_zero();
self.construct_numpy_eye(ctx, dtype, size, size, offset, name)
}

View File

@@ -47,19 +47,19 @@ impl<'ctx> NDIndexType<'ctx> {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
in_ndindices: &[RustNDIndex<'ctx>],
) -> ArraySliceValue<'ctx> {
) -> anyhow::Result<ArraySliceValue<'ctx>> {
// Allocate the LLVM ndindices.
let ty = self.alloca_ty(ctx);
let ndindices = gen_array_var(ctx, ty, in_ndindices.len() as u64, None);
let ndindices = gen_array_var(ctx, ty, in_ndindices.len() as u64, None)?;
// Initialize all of them.
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
let pndindex =
ndindices.ptr_offset_unchecked(ctx, &ctx.i64.const_int(i as _, false), None);
in_ndindex.write_to_ndindex(ctx, self.map_value(pndindex, None));
ndindices.ptr_offset_unchecked(ctx, &ctx.i64.const_int(i as _, false), None)?;
in_ndindex.write_to_ndindex(ctx, self.map_value(pndindex, None))?;
}
ndindices
Ok(ndindices)
}
}
@@ -105,7 +105,7 @@ impl<'ctx> SliceValue<'ctx> {
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<Self, String> {
) -> anyhow::Result<Self> {
fn write_value<'ctx>(
generator: &mut impl CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
@@ -113,21 +113,21 @@ impl<'ctx> SliceValue<'ctx> {
result: SliceValue<'ctx>,
defined: impl FnOnce(&SliceType<'ctx>) -> StructField<'ctx, IntValue<'ctx>>,
val: impl FnOnce(&SliceType<'ctx>) -> StructField<'ctx, IntValue<'ctx>>,
) -> Result<(), String> {
) -> anyhow::Result<()> {
match value_expr {
// Not defined
None => result.store(ctx, defined, ctx.i1.const_zero()),
None => result.store(ctx, defined, ctx.i1.const_zero())?,
Some(value_expr) => {
let value = generator.gen_expr(ctx, value_expr)?.to_basic_value_enum(ctx)?;
result.store(ctx, defined, ctx.i1.const_int(1, false));
result.store(ctx, val, value.into_int_value());
result.store(ctx, defined, ctx.i1.const_int(1, false))?;
result.store(ctx, val, value.into_int_value())?;
}
}
Ok(())
}
let ty = SliceType::new(ctx);
let result = ty.alloca(ctx, None);
let result = ty.alloca(ctx, None)?;
write_value(generator, ctx, lower, result, field!(start_defined), field!(start))?;
write_value(generator, ctx, upper, result, field!(stop_defined), field!(stop))?;
@@ -159,7 +159,7 @@ impl<'ctx> RustNDIndex<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<Self>, String> {
) -> anyhow::Result<Vec<Self>> {
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
@@ -218,22 +218,24 @@ impl<'ctx> RustNDIndex<'ctx> {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dst_ndindex: NDIndexValue<'ctx>,
) {
) -> anyhow::Result<()> {
// Set `dst_ndindex.type`
dst_ndindex.store(ctx, field!(type_), ctx.i8.const_int(self.get_type_id(), false));
dst_ndindex.store(ctx, field!(type_), ctx.i8.const_int(self.get_type_id(), false))?;
// Set `dst_ndindex_ptr->data`
match *self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = gen_var(ctx, ctx.i32, None);
typed_store(ctx.builder, index_ptr, in_index);
dst_ndindex.store(ctx, field!(data), index_ptr);
let index_ptr = gen_var(ctx, ctx.i32, None)?;
typed_store(ctx.builder, index_ptr, in_index)?;
dst_ndindex.store(ctx, field!(data), index_ptr)?;
}
RustNDIndex::Slice(slice) => {
dst_ndindex.store(ctx, field!(data), slice.value);
dst_ndindex.store(ctx, field!(data), slice.value)?;
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
Ok(())
}
}
@@ -259,16 +261,19 @@ impl<'ctx> NDArrayValue<'ctx> {
///
/// This function behaves like NumPy's ndarray indexing, but if the indices index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index(&self, ctx: &mut CodeGenContext<'ctx, '_>, indices: &[RustNDIndex<'ctx>]) -> Self {
pub fn index(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: &[RustNDIndex<'ctx>],
) -> anyhow::Result<Self> {
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
let dst = NDArrayType::new(ctx, self.ty.dtype, dst_ndims).construct(ctx, None);
let indices = NDIndexType::new(ctx).construct(ctx, indices);
let dst = NDArrayType::new(ctx, self.ty.dtype, dst_ndims).construct(ctx, None)?;
let indices = NDIndexType::new(ctx).construct(ctx, indices)?;
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index");
let (idx_ptr, idx_len) = indices.value;
call_extern!(ctx: void _ = name(idx_len, idx_ptr, self.value, dst.value));
call_extern!(ctx: void _ = name(idx_len, idx_ptr, self.value, dst.value))?;
dst
Ok(dst)
}
}

View File

@@ -1,3 +1,4 @@
use anyhow::anyhow;
use inkwell::{
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
@@ -56,61 +57,71 @@ pub type NDIterValue<'ctx> = Value<'ctx, NDIterType<'ctx>>;
impl<'ctx> NDIterValue<'ctx> {
/// Creates an iterator that iterates through `ndarray`.
#[must_use]
pub fn new(ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>) -> Self {
pub fn new(
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> anyhow::Result<Self> {
let ty = NDIterType {
inner: BuiltinStruct::new(ctx, "nditer"),
dtype: ndarray.ty.dtype,
ndims: ndarray.ty.ndims,
};
let nditer = ty.alloca(ctx, None);
let nditer = ty.alloca(ctx, None)?;
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices = gen_array_var(ctx, ctx.size_t, ndarray.ty.ndims, None).value.0;
let indices = gen_array_var(ctx, ctx.size_t, ndarray.ty.ndims, None)?.value.0;
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
call_extern!(ctx: void _ = name(nditer.value, ndarray.value, indices));
call_extern!(ctx: void _ = name(nditer.value, ndarray.value, indices))?;
nditer
Ok(nditer)
}
/// Advances the iterator to the next element.
pub fn next(&self, ctx: &mut CodeGenContext<'ctx, '_>) {
pub fn next(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<()> {
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next");
call_extern!(ctx: void _ = name(self.value));
call_extern!(ctx: void _ = name(self.value))?;
Ok(())
}
/// Returns whether the iterator is currently referring to a valid element.
#[must_use]
pub fn has_element(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
pub fn has_element(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<IntValue<'ctx>> {
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element");
call_extern!(ctx: (ctx.i1) _ = name(self.value))
}
/// Returns a pointer to the current element.
#[must_use]
pub fn curr_ptr(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
pub fn curr_ptr(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<PointerValue<'ctx>> {
self.load(ctx, field!(element))
}
/// Loads and returns the current element as a scalar value.
#[must_use]
pub fn get_scalar(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let p = self.curr_ptr(ctx);
pub fn get_scalar(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<BasicValueEnum<'ctx>> {
let p = self.curr_ptr(ctx)?;
typed_load(ctx.builder, p, self.ty.dtype, "value")
}
/// Returns the current iteration index (i.e., how many elements have been iterated so far).
#[must_use]
pub fn get_index(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
pub fn get_index(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<IntValue<'ctx>> {
self.load(ctx, field!(nth))
}
/// Returns the current indices in each dimension as an array.
#[must_use]
pub fn indices(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> ArraySliceValue<'ctx> {
let indices_ptr = self.load(ctx, field!(indices));
ArraySliceValue::new(ctx.size_t.into(), indices_ptr, self.ty.ndims_val(ctx), self.name)
pub fn indices(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<ArraySliceValue<'ctx>> {
let indices_ptr = self.load(ctx, field!(indices))?;
Ok(ArraySliceValue::new(ctx.size_t.into(), indices_ptr, self.ty.ndims_val(ctx), self.name))
}
}
@@ -119,23 +130,23 @@ impl<'ctx> NDArrayValue<'ctx> {
///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.)
pub fn foreach<'a, F>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, body: F) -> Result<(), String>
pub fn foreach<'a, F>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, body: F) -> anyhow::Result<()>
where
F: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterValue<'ctx>,
) -> Result<(), String>,
) -> anyhow::Result<()>,
{
gen_for_callback(
&mut (),
ctx,
Some("ndarray_foreach"),
|(), ctx| Ok(NDIterValue::new(ctx, *self)),
|(), ctx, nditer| Ok(nditer.has_element(ctx)),
|(), ctx| NDIterValue::new(ctx, *self),
|(), ctx, nditer| nditer.has_element(ctx),
|(), ctx, hooks, nditer| body(ctx, hooks, nditer),
|(), ctx, nditer| {
nditer.next(ctx);
nditer.next(ctx)?;
Ok(())
},
|(), _| Ok(()),
@@ -153,7 +164,7 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, 'a>,
init: V,
f: F,
) -> Result<V, String>
) -> anyhow::Result<V>
where
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error: std::fmt::Debug> + Copy,
F: FnOnce(
@@ -161,33 +172,34 @@ impl<'ctx> NDArrayValue<'ctx> {
BreakContinueHooks<'ctx>,
V,
NDIterValue<'ctx>,
) -> Result<V, String>,
) -> anyhow::Result<V>,
{
let init = init.as_basic_value_enum();
let acc_ptr = gen_var(ctx, init.get_type(), None);
typed_store(ctx.builder, acc_ptr, init);
let acc_ptr = gen_var(ctx, init.get_type(), None)?;
typed_store(ctx.builder, acc_ptr, init)?;
gen_for_callback(
&mut (),
ctx,
Some("ndarray_fold"),
|(), ctx| Ok(NDIterValue::new(ctx, *self)),
|(), ctx, nditer| Ok(nditer.has_element(ctx)),
|(), ctx| NDIterValue::new(ctx, *self),
|(), ctx, nditer| nditer.has_element(ctx),
|(), ctx, hooks, nditer| {
let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap();
let acc = V::try_from(ctx.builder.build_load(acc_ptr, "")?)
.map_err(|e| anyhow!("{e:?}"))?;
let acc = f(ctx, hooks, acc, nditer)?;
typed_store(ctx.builder, acc_ptr, acc);
typed_store(ctx.builder, acc_ptr, acc)?;
Ok(())
},
|(), ctx, nditer| {
nditer.next(ctx);
nditer.next(ctx)?;
Ok(())
},
|(), _| Ok(()),
)?;
let acc = ctx.builder.build_load(acc_ptr, "").unwrap();
Ok(V::try_from(acc).unwrap())
let acc = ctx.builder.build_load(acc_ptr, "")?;
V::try_from(acc).map_err(|e| anyhow!("{e:?}"))
}
}
@@ -205,7 +217,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ctx: &mut CodeGenContext<'ctx, 'a>,
init: V,
f: F,
) -> Result<V, String>
) -> anyhow::Result<V>
where
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error: std::fmt::Debug> + Copy,
F: FnOnce(
@@ -213,12 +225,12 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
Option<&BreakContinueHooks<'ctx>>,
V,
BasicValueEnum<'ctx>,
) -> Result<V, String>,
) -> anyhow::Result<V>,
{
match self {
ScalarOrNDArray::Scalar(v) => f(ctx, None, init, *v),
ScalarOrNDArray::NDArray(v) => v.fold(ctx, init, |ctx, hooks, acc, nditer| {
let elem = nditer.get_scalar(ctx);
let elem = nditer.get_scalar(ctx)?;
f(ctx, Some(&hooks), acc, elem)
}),
}

View File

@@ -28,7 +28,7 @@ fn matmul_at_least_2d<'ctx>(
dst_dtype: Type,
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
) -> NDArrayValue<'ctx> {
) -> anyhow::Result<NDArrayValue<'ctx>> {
assert!(in_a.ty.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ty.ndims);
assert!(in_b.ty.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ty.ndims);
@@ -44,10 +44,11 @@ fn matmul_at_least_2d<'ctx>(
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
// destination ndarray to store the result of matmul.
let (lhs, rhs, dst) = {
let in_lhs_shape = in_a.shape(ctx);
let in_rhs_shape = in_b.shape(ctx);
let in_lhs_shape = in_a.shape(ctx)?;
let in_rhs_shape = in_b.shape(ctx)?;
let [lhs_shape, rhs_shape, dst_shape] =
core::array::from_fn(|_| gen_array_var(ctx, ctx.size_t, ndims_int, None));
let [lhs_shape, rhs_shape, dst_shape] = [lhs_shape?, rhs_shape?, dst_shape?];
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");
call_extern!(ctx: void _ = name(
@@ -57,17 +58,18 @@ fn matmul_at_least_2d<'ctx>(
lhs_shape.value.0,
rhs_shape.value.0,
dst_shape.value.0,
));
))?;
let lhs = in_a.broadcast_to(ctx, ndims_int, lhs_shape);
let rhs = in_b.broadcast_to(ctx, ndims_int, rhs_shape);
let lhs = in_a.broadcast_to(ctx, ndims_int, lhs_shape)?;
let rhs = in_b.broadcast_to(ctx, ndims_int, rhs_shape)?;
let dst =
NDArrayOut::NewNDArray { dtype: llvm_dst_dtype }.resolve(ctx, ndims_int, dst_shape);
NDArrayOut::NewNDArray { dtype: llvm_dst_dtype }.resolve(ctx, ndims_int, dst_shape)?;
(lhs, rhs, dst)
};
let len = lhs.shape(ctx).get_unchecked(ctx, &ctx.size_t.const_int(ndims_int - 1, false), None);
let len =
lhs.shape(ctx)?.get_unchecked(ctx, &ctx.size_t.const_int(ndims_int - 1, false), None)?;
let [at_row, at_col] = [ndims_int - 2, ndims_int - 1].map(|x| ctx.size_t.const_int(x, true));
@@ -75,13 +77,13 @@ fn matmul_at_least_2d<'ctx>(
let dst_zero = dst_dtype_llvm.const_zero();
dst.foreach(ctx, |ctx, _, hdl| {
let pdst_ij = hdl.curr_ptr(ctx);
let pdst_ij = hdl.curr_ptr(ctx)?;
typed_store(ctx.builder, pdst_ij, dst_zero);
typed_store(ctx.builder, pdst_ij, dst_zero)?;
let indices = hdl.indices(ctx);
let i = indices.get_unchecked::<IntValue<'ctx>>(ctx, &at_row, None);
let j = indices.get_unchecked::<IntValue<'ctx>>(ctx, &at_col, None);
let indices = hdl.indices(ctx)?;
let i = indices.get_unchecked::<IntValue<'ctx>>(ctx, &at_row, None)?;
let j = indices.get_unchecked::<IntValue<'ctx>>(ctx, &at_col, None)?;
let num_0 = ctx.size_t.const_int(0, false);
let num_1 = ctx.size_t.const_int(1, false);
@@ -94,17 +96,17 @@ fn matmul_at_least_2d<'ctx>(
(len, false),
|(), ctx, _, k| {
// `indices` is modified to index into `a` and `b`, and restored.
indices.set_unchecked(ctx, &at_row, i, None);
indices.set_unchecked(ctx, &at_col, k, None);
let a_ik = lhs.get_unchecked(ctx, &indices, None);
indices.set_unchecked(ctx, &at_row, i, None)?;
indices.set_unchecked(ctx, &at_col, k, None)?;
let a_ik = lhs.get_unchecked(ctx, &indices, None)?;
indices.set_unchecked(ctx, &at_row, k, None);
indices.set_unchecked(ctx, &at_col, j, None);
let b_kj = rhs.get_unchecked(ctx, &indices, None);
indices.set_unchecked(ctx, &at_row, k, None)?;
indices.set_unchecked(ctx, &at_col, j, None)?;
let b_kj = rhs.get_unchecked(ctx, &indices, None)?;
// Restore `indices`.
indices.set_unchecked(ctx, &at_row, i, None);
indices.set_unchecked(ctx, &at_col, j, None);
indices.set_unchecked(ctx, &at_row, i, None)?;
indices.set_unchecked(ctx, &at_col, j, None)?;
// x = a_[...]ik * b_[...]kj
let x = gen_prim_binop_expr(
@@ -116,7 +118,7 @@ fn matmul_at_least_2d<'ctx>(
.expect("matmul: ndarray should contain primtives only");
// dst_[...]ij += x
let dst_ij = typed_load(ctx.builder, pdst_ij, dst_dtype_llvm, "");
let dst_ij = typed_load(ctx.builder, pdst_ij, dst_dtype_llvm, "")?;
let dst_ij = gen_prim_binop_expr(
ctx,
(&Some(dst_dtype), dst_ij),
@@ -124,17 +126,16 @@ fn matmul_at_least_2d<'ctx>(
(&Some(dst_dtype), x),
)?
.expect("matmul: ndarray should contain primtives only");
typed_store(ctx.builder, pdst_ij, dst_ij);
typed_store(ctx.builder, pdst_ij, dst_ij)?;
Ok(())
},
num_1,
|(), _| Ok(()),
)
})
.unwrap();
})?;
dst
Ok(dst)
}
impl<'ctx> NDArrayValue<'ctx> {
@@ -144,14 +145,13 @@ impl<'ctx> NDArrayValue<'ctx> {
/// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar.
///
/// `dst_dtype` defines the dtype of the returned ndarray.
#[must_use]
pub fn matmul(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
self_ty: Type,
(other_ty, other): (Type, Self),
(out_dtype, out): (Type, NDArrayOut<'ctx>),
) -> Self {
) -> anyhow::Result<Self> {
// Sanity check, but type inference should prevent this.
assert!(self.ty.ndims > 0 && other.ty.ndims > 0, "np.matmul disallows scalar input");
@@ -170,19 +170,19 @@ impl<'ctx> NDArrayValue<'ctx> {
// Prepend 1 to its dimensions
self.index(ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
} else {
*self
};
Ok(*self)
}?;
let new_b = if other.ty.ndims == 1 {
// Append 1 to its dimensions
other.index(ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
} else {
other
};
Ok(other)
}?;
// NOTE: `result` will always be a newly allocated ndarray.
// Current implementation cannot do in-place matrix muliplication.
let mut result = matmul_at_least_2d(ctx, out_dtype, (self_ty, new_a), (other_ty, new_b));
let mut result = matmul_at_least_2d(ctx, out_dtype, (self_ty, new_a), (other_ty, new_b))?;
// Postprocessing on the result to remove prepended/appended axes.
let mut postindices = vec![];
@@ -200,19 +200,19 @@ impl<'ctx> NDArrayValue<'ctx> {
}
if !postindices.is_empty() {
result = result.index(ctx, &postindices);
result = result.index(ctx, &postindices)?;
}
match out {
Ok(match out {
NDArrayOut::NewNDArray { .. } => result,
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
let result_shape = result.shape(ctx);
let out_shape = out_ndarray.shape(ctx);
assert_ndarray_can_be_written_by_out(ctx, result_shape, out_shape);
let result_shape = result.shape(ctx)?;
let out_shape = out_ndarray.shape(ctx)?;
assert_ndarray_can_be_written_by_out(ctx, result_shape, out_shape)?;
out_ndarray.copy_data_from(ctx, &result);
out_ndarray.copy_data_from(ctx, &result)?;
out_ndarray
}
}
})
}
}

View File

@@ -3,7 +3,6 @@ use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
};
use itertools::Itertools as _;
use nac3core_derive::{ProxyType, StructFields};
use crate::{
@@ -68,9 +67,9 @@ impl<'ctx, S> Value<'ctx, NDArrayLikeType<'ctx, S>> {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
field: impl FnOnce(&NDArrayLikeType<'ctx, S>) -> StructField<'ctx, PointerValue<'ctx>>,
) -> ArraySliceValue<'ctx> {
let ptr = self.load(ctx, field);
ArraySliceValue::new(ctx.size_t.into(), ptr, self.ty.ndims_val(ctx), self.name)
) -> anyhow::Result<ArraySliceValue<'ctx>> {
let ptr = self.load(ctx, field)?;
Ok(ArraySliceValue::new(ctx.size_t.into(), ptr, self.ty.ndims_val(ctx), self.name))
}
}
@@ -122,46 +121,44 @@ impl<'ctx> NDArrayType<'ctx> {
/// Once you properly set up the `shape` array, you can construct a fully usable ndarray with
/// [`create_data`][NDArrayValue::create_data]. To construct a fully usable ndarray directly
/// when the shape is known, use [`NDArrayType::with_shape`].
#[must_use]
pub fn construct(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
let ndarray = self.alloca(ctx, name);
) -> anyhow::Result<NDArrayValue<'ctx>> {
let ndarray = self.alloca(ctx, name)?;
let size = self.itemsize_val(ctx);
ndarray.store(ctx, field!(itemsize), size);
ndarray.store(ctx, field!(itemsize), size)?;
let (ndims_int, ndims) = (self.ndims, self.ndims_val(ctx));
ndarray.store(ctx, field!(ndims), ndims);
ndarray.store(ctx, field!(ndims), ndims)?;
let shape = gen_array_var(ctx, ctx.size_t, ndims_int, None).value.0;
ndarray.store(ctx, field!(shape), shape);
let strides = gen_array_var(ctx, ctx.size_t, ndims_int, None).value.0;
ndarray.store(ctx, field!(strides), strides);
let shape = gen_array_var(ctx, ctx.size_t, ndims_int, None)?.value.0;
ndarray.store(ctx, field!(shape), shape)?;
let strides = gen_array_var(ctx, ctx.size_t, ndims_int, None)?.value.0;
ndarray.store(ctx, field!(strides), strides)?;
ndarray
Ok(ndarray)
}
/// Creates a new, contiguous `NDArrayValue` with a given shape.
///
/// The shape array is initialized to `shape`. The strides array is prepared accordingly.
/// The data array is allocated but uninitialized.
#[must_use]
pub fn with_shape(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[IntValue<'ctx>],
name: Option<&'static str>,
) -> NDArrayValue<'ctx> {
let ndarray = self.construct(ctx, name);
let dst = ndarray.shape(ctx);
) -> anyhow::Result<NDArrayValue<'ctx>> {
let ndarray = self.construct(ctx, name)?;
let dst = ndarray.shape(ctx)?;
for (i, &dim) in shape.iter().enumerate() {
let i = ctx.size_t.const_int(i as _, false);
dst.set_unchecked(ctx, &i, dim, name);
dst.set_unchecked(ctx, &i, dim, name)?;
}
ndarray.create_data(ctx);
ndarray
ndarray.create_data(ctx)?;
Ok(ndarray)
}
}
@@ -169,116 +166,130 @@ pub type NDArrayValue<'ctx> = Value<'ctx, NDArrayType<'ctx>>;
impl<'ctx> NDArrayValue<'ctx> {
/// Returns the shape of this array.
#[must_use]
pub fn shape(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> ArraySliceValue<'ctx> {
pub fn shape(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<ArraySliceValue<'ctx>> {
self.load_ndims_slice(ctx, field!(shape))
}
/// Returns the strides of this array.
#[must_use]
pub fn strides(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> ArraySliceValue<'ctx> {
pub fn strides(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<ArraySliceValue<'ctx>> {
self.load_ndims_slice(ctx, field!(strides))
}
/// Returns a new scalar `NDArrayValue` containing `value`.
///
/// The returned value has 0 dimensions.
#[must_use]
pub fn new_scalar(
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
name: Option<&'static str>,
) -> Self {
) -> anyhow::Result<Self> {
let dtype = value.get_type();
let ndarray = NDArrayType::new(ctx, dtype, 0).construct(ctx, name);
let data = gen_var(ctx, value.get_type(), Some("map_unsized"));
typed_store(ctx.builder, data, value);
let data = ctx.builder.build_pointer_cast(data, ctx.ptr, "").unwrap();
ndarray.store(ctx, field!(data), data);
ndarray
let ndarray = NDArrayType::new(ctx, dtype, 0).construct(ctx, name)?;
let data = gen_var(ctx, value.get_type(), Some("map_unsized"))?;
typed_store(ctx.builder, data, value)?;
let data = ctx.builder.build_pointer_cast(data, ctx.ptr, "")?;
ndarray.store(ctx, field!(data), data)?;
Ok(ndarray)
}
/// Computes the total number of (scalar) elements in this array.
#[must_use]
pub fn size(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let shape = self.shape(ctx);
pub fn size(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<IntValue<'ctx>> {
let shape = self.shape(ctx)?;
let mut product = ctx.size_t.const_int(1, false);
for i in 0..self.ty.ndims {
let idx = ctx.size_t.const_int(i, false);
let dim = shape.get_unchecked(ctx, &idx, None);
product = ctx.builder.build_int_mul(product, dim, "").unwrap();
let dim = shape.get_unchecked(ctx, &idx, None)?;
product = ctx.builder.build_int_mul(product, dim, "")?;
}
product
Ok(product)
}
/// Allocates contiguous memory for the data array and assigns strides correspondingly.
///
/// Assumes `shape` has been correctly prepared.
pub fn create_data(&self, ctx: &mut CodeGenContext<'ctx, '_>) {
let size = self.size(ctx);
let alloc = gen_dyn_array_var(ctx, self.ty.dtype, size, None).value.0;
self.store(ctx, field!(data), alloc);
self.set_strides_contiguous(ctx);
pub fn create_data(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<()> {
let size = self.size(ctx)?;
let alloc = gen_dyn_array_var(ctx, self.ty.dtype, size, None)?.value.0;
self.store(ctx, field!(data), alloc)?;
self.set_strides_contiguous(ctx)?;
Ok(())
}
/// Assigns strides for a contiguous array.
///
/// Assumes `shape` has been correctly prepared.
pub fn set_strides_contiguous(&self, ctx: &mut CodeGenContext<'ctx, '_>) {
let shape = self.shape(ctx);
let strides = self.strides(ctx);
pub fn set_strides_contiguous(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<()> {
let shape = self.shape(ctx)?;
let strides = self.strides(ctx)?;
let mut stride = self.ty.itemsize_val(ctx);
for i in (0..self.ty.ndims).rev() {
let idx = ctx.size_t.const_int(i, false);
strides.set_unchecked(ctx, &idx, stride, self.name);
let dim = shape.get_unchecked(ctx, &idx, None);
stride = ctx.builder.build_int_mul(stride, dim, "").unwrap();
strides.set_unchecked(ctx, &idx, stride, self.name)?;
let dim = shape.get_unchecked(ctx, &idx, None)?;
stride = ctx.builder.build_int_mul(stride, dim, "")?;
}
Ok(())
}
/// Returns the length of the first dimension of the array.
#[must_use]
pub fn len(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
pub fn len(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<IntValue<'ctx>> {
assert!(self.ty.ndims >= 1);
self.shape(ctx).get_unchecked(ctx, &ctx.size_t.const_zero(), self.name)
self.shape(ctx)?.get_unchecked(ctx, &ctx.size_t.const_zero(), self.name)
}
/// Returns the number of bytes consumed by the array data.
#[must_use]
pub fn nbytes(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let size = self.size(ctx);
pub fn nbytes(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<IntValue<'ctx>> {
let size = self.size(ctx)?;
let itemsize = self.ty.itemsize_val(ctx);
ctx.builder.build_int_mul(size, itemsize, "").unwrap()
Ok(ctx.builder.build_int_mul(size, itemsize, "")?)
}
/// Checks if the array is C-contiguous.
#[must_use]
pub fn is_c_contiguous(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
pub fn is_c_contiguous(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<IntValue<'ctx>> {
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
call_extern!(ctx: (ctx.i1) "is_c_contiguous" = name(self.value))
}
/// Creates a copy of this array.
#[must_use]
pub fn make_copy(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let shape = self.shape(ctx);
pub fn make_copy(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> anyhow::Result<Self> {
let shape = self.shape(ctx)?;
let clone =
NDArrayOut::NewNDArray { dtype: self.ty.dtype }.resolve(ctx, self.ty.ndims, shape);
clone.copy_data_from(ctx, self);
clone
NDArrayOut::NewNDArray { dtype: self.ty.dtype }.resolve(ctx, self.ty.ndims, shape)?;
clone.copy_data_from(ctx, self)?;
Ok(clone)
}
/// Copies data from `src` into this array.
pub fn copy_data_from(&self, ctx: &mut CodeGenContext<'ctx, '_>, src: &Self) {
pub fn copy_data_from(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
src: &Self,
) -> anyhow::Result<()> {
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data");
call_extern!(ctx: void _ = name(src.value, self.value));
call_extern!(ctx: void _ = name(src.value, self.value))?;
Ok(())
}
/// Copies the shape of `src` into this array.
pub fn copy_shape_from(&self, ctx: &mut CodeGenContext<'ctx, '_>, src: &Self) {
let shape = src.shape(ctx);
self.shape(ctx).memcpy_from(ctx, shape.value.0);
pub fn copy_shape_from(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
src: &Self,
) -> anyhow::Result<()> {
let shape = src.shape(ctx)?;
self.shape(ctx)?.memcpy_from(ctx, shape.value.0)?;
Ok(())
}
fn read_shape_or_stride_as_tuple(
@@ -286,61 +297,71 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
arr: ArraySliceValue<'ctx>,
name: &'static str,
) -> TupleValue<'ctx> {
// let types = vec![ctx.size_t.into(); self.ty.ndims as usize];
// let ty = TupleType::new(ctx, &types);
) -> anyhow::Result<TupleValue<'ctx>> {
let values = (0..self.ty.ndims)
.map(|i| {
let idx = ctx.size_t.const_int(i as _, false);
let val = arr.get_unchecked::<IntValue<'ctx>>(ctx, &idx, None);
ctx.builder.build_int_truncate_or_bit_cast(val, ctx.i32, "").unwrap()
let val = arr.get_unchecked::<IntValue<'ctx>>(ctx, &idx, None)?;
Ok(ctx.builder.build_int_truncate_or_bit_cast(val, ctx.i32, "")?)
})
.collect_vec();
.collect::<anyhow::Result<Vec<_>>>()?;
TupleValue::new(ctx, &values, Some(name))
}
/// Returns a `tuple` representing the shape of this array.
#[must_use]
pub fn make_shape_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> TupleValue<'ctx> {
let shape = self.shape(ctx);
pub fn make_shape_tuple(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<TupleValue<'ctx>> {
let shape = self.shape(ctx)?;
self.read_shape_or_stride_as_tuple(ctx, shape, "shape")
}
/// Returns a `tuple` representing the strides of this array.
#[must_use]
pub fn make_strides_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> TupleValue<'ctx> {
let strides = self.strides(ctx);
pub fn make_strides_tuple(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> anyhow::Result<TupleValue<'ctx>> {
let strides = self.strides(ctx)?;
self.read_shape_or_stride_as_tuple(ctx, strides, "strides")
}
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
/// Otherwise, do nothing and return the ndarray itself.
#[must_use]
pub fn split_unsized(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> ScalarOrNDArray<'ctx> {
if self.ty.ndims == 0 {
ScalarOrNDArray::Scalar(self.first_element(ctx))
pub fn split_unsized(
&self,