ndstrides: [14] Final cleanups #524

Open
lyken wants to merge 4 commits from ndstrides-14-end into ndstrides-13-tests
9 changed files with 251 additions and 2838 deletions

View File

@ -14,26 +14,28 @@ use pyo3::{
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{ classes::{ListValue, RangeValue, UntypedArrayLikeAccessor},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
},
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, model::*,
object::{any::AnyObject, ndarray::NDArrayObject},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{ inkwell::{
context::Context, context::Context,
module::Linkage, module::Linkage,
types::{BasicType, IntType}, types::IntType,
values::{BasicValueEnum, IntValue, PointerValue, StructValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}, },
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}, nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, toplevel::{
helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys,
DefinitionId, GenCall,
},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
@ -454,55 +456,42 @@ fn format_rpc_arg<'ctx>(
// NAC3: NDArray = { usize, usize*, T* } // NAC3: NDArray = { usize, usize*, T* }
// libproto_artiq: NDArray = [data[..], dim_sz[..]] // libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type(); let ndarray = AnyObject { ty: arg_ty, value: arg };
let llvm_usize = generator.get_size_type(ctx.ctx); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let dtype = ctx.get_llvm_type(generator, ndarray.dtype);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
.builder // the reader.
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") // Turning it into a ContiguousNDArray to get a `data` that is contiguous.
.unwrap(); let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype));
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz = let sizeof_sizet = Int(SizeT).size_of(generator, ctx.ctx);
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet);
let buffer_size = let sizeof_pdata = Ptr(Any(dtype)).size_of(generator, ctx.ctx);
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata);
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims);
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata);
call_memcpy_generic( // buf = { data: void*, shape: [size_t; ndims]; }
ctx, let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value);
buffer.base_ptr(ctx, generator), let buf_data = buf;
llvm_arg.ptr_to_data(ctx), let buf_shape = buf_data.offset(ctx, sizeof_pdata.value);
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin = // Write to `buf->data`
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr<Any>
call_memcpy_generic( let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte));
ctx, buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value);
pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator) // Write to `buf->shape`
let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte));
buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value);
buf.value
} }
_ => { _ => {
@ -563,8 +552,10 @@ fn format_rpc_ret<'ctx>(
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type(); // FIXME: It is possible to rewrite everything more neatly with `Model<'ctx>`, but this is not too important.
let llvm_usize = generator.get_size_type(ctx.ctx);
let num_0 = Int(SizeT).const_0(generator, ctx.ctx);
let num_8 = Int(SizeT).const_int(generator, ctx.ctx, 8, false);
// Round `val` up to its modulo `power_of_two` // Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>, let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
@ -590,60 +581,36 @@ fn format_rpc_ret<'ctx>(
.unwrap() .unwrap()
}; };
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray // Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off. // A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
// Setup ndims // NOTE: Current content of `ndarray`:
let ndims = // - * `data` - **NOT YET** allocated.
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { // - * `itemsize` - initialized to be size_of(dtype).
assert_eq!(values.len(), 1); // - * `ndims` - initialized.
// - * `shape` - allocated; has uninitialized values.
// - * `strides` - allocated; has uninitialized values.
u64::try_from(values[0].clone()).unwrap() let itemsize = ndarray.instance.get(generator, ctx, |f| f.itemsize); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`.
} else { let dtype_llvm = ctx.get_llvm_type(generator, dtype);
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment // (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims = let sizeof_size_t = Int(SizeT).size_of(generator, ctx.ctx);
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_size_t = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_size_t); // sizeof(size_t)
let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap(); let sizeof_ptr = Ptr(Int(Byte)).size_of(generator, ctx.ctx);
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false)); let sizeof_ptr = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_ptr); // sizeof(uint8_t*)
let sizeof_shape = ndarray.ndims_llvm(generator, ctx.ctx).mul(ctx, sizeof_size_t); // sizeof([size_t; ndims]); same as the # of bytes of `ndarray.shape`.
// Size of the buffer for the initial `rpc_recv()`.
let unaligned_buffer_size = sizeof_ptr.add(ctx, sizeof_shape); // sizeof(uint8_t*) + sizeof([size_t; ndims]).
let buffer_size = round_up(ctx, unaligned_buffer_size.value, num_8.value);
let buffer_size = unsafe { Int(SizeT).believe_value(buffer_size) };
let stackptr = call_stacksave(ctx, None); let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment // Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
@ -651,9 +618,7 @@ fn format_rpc_ret<'ctx>(
.builder .builder
.build_array_alloca( .build_array_alloca(
llvm_i8_8, llvm_i8_8,
ctx.builder ctx.builder.build_int_unsigned_div(buffer_size.value, num_8.value, "").unwrap(),
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
.unwrap(),
"rpc.buffer", "rpc.buffer",
) )
.unwrap(); .unwrap();
@ -662,7 +627,7 @@ fn format_rpc_ret<'ctx>(
.build_bit_cast(buffer, llvm_pi8, "") .build_bit_cast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None); let buffer = unsafe { Ptr(Int(Byte)).believe_value(buffer) };
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape] // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
// //
@ -670,24 +635,20 @@ fn format_rpc_ret<'ctx>(
let ndarray_nbytes = ctx let ndarray_nbytes = ctx
.build_call_or_invoke( .build_call_or_invoke(
rpc_recv, rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims]. &[buffer.value.into()], // Reads [usize; ndims]
"rpc.size.next", "rpc.size.next",
) )
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let ndarray_nbytes = unsafe { Int(SizeT).believe_value(ndarray_nbytes) };
// debug_assert(ndarray_nbytes > 0) // debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let cmp = ndarray_nbytes.compare(ctx, IntPredicate::UGT, num_0);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder cmp.value,
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next", "Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None], [None, None, None],
@ -696,49 +657,39 @@ fn format_rpc_ret<'ctx>(
} }
// Copy shape from the buffer to `ndarray.shape`. // Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims = // We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let pbuffer_shape = buffer.offset(ctx, sizeof_ptr.value);
let pbuffer_shape = pbuffer_shape.pointer_cast(generator, ctx, Int(SizeT));
// Copy shape from buffer to `ndarray.shape`
ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer // Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`. // Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation // `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate) // (for ndarray.create_data() to know how many elements to allocate)
let num_elements = ndarray.create_data(generator, ctx); // NOTE: the strides of `ndarray` has also been set to contiguous in `::create_data()`.
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data = let num_elements = ndarray.size(generator, ctx);
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
let expected_ndarray_nbytes = num_elements.mul(ctx, itemsize);
let cmp = expected_ndarray_nbytes.compare(ctx, IntPredicate::UGE, ndarray_nbytes);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::UGE, cmp.value,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None], [Some(expected_ndarray_nbytes.value), Some(ndarray_nbytes.value), None],
ctx.current_loc, ctx.current_loc,
); );
} }
ndarray.create_data(ctx, llvm_elem_ty, num_elements); let ndarray_data = ndarray.instance.get(generator, ctx, |f| f.data);
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb` // NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
@ -747,7 +698,7 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(head_bb); ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]); phi.add_incoming(&[(&ndarray_data.value, prehead_bb)]);
let alloc_size = ctx let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
@ -762,12 +713,12 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(alloc_bb); ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T) // Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof); let alloc_size = round_up(ctx, alloc_size, itemsize.value);
let alloc_ptr = ctx let alloc_ptr = ctx
.builder .builder
.build_array_alloca( .build_array_alloca(
llvm_elem_ty, dtype_llvm,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(), ctx.builder.build_int_unsigned_div(alloc_size, itemsize.value, "").unwrap(),
"rpc.alloc", "rpc.alloc",
) )
.unwrap(); .unwrap();
@ -777,7 +728,7 @@ fn format_rpc_ret<'ctx>(
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb); ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into() ndarray.instance.value.as_basic_value_enum()
} }
_ => { _ => {
@ -1359,56 +1310,46 @@ fn polymorphic_print<'ctx>(
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
fmt.push_str("array(["); fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); let ndarray = AnyObject { ty, value };
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
gen_for_callback_incrementing( let num_0 = Int(SizeT).const_0(generator, ctx.ctx);
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
polymorphic_print( // Print `ndarray` as a flat list delimited by interspersed with ", \0"
ctx, ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
generator, let i = hdl.get_index(generator, ctx);
&[(elem_ty, elem.into())], let scalar = hdl.get_scalar(generator, ctx);
"",
None,
true,
as_rtio,
)?;
// if (i != 0) { puts(", "); }
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|_, ctx| { |_, ctx| {
Ok(ctx let not_first = i.compare(ctx, IntPredicate::NE, num_0);
.builder Ok(not_first.value)
.build_int_compare(IntPredicate::ULT, i, last, "")
.unwrap())
}, },
|generator, ctx| { |generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default()); printf(ctx, generator, ", \0".into(), Vec::default());
Ok(()) Ok(())
}, },
|_, _| Ok(()), |_, _| Ok(()),
)?; )?;
Ok(()) // Print element
}, polymorphic_print(
llvm_usize.const_int(1, false), ctx,
generator,
&[(scalar.ty, scalar.value.into())],
"",
None,
true,
as_rtio,
)?; )?;
Ok(())
})?;
fmt.push_str(")]"); fmt.push_str(")]");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);

View File

@ -10,18 +10,19 @@ use itertools::Itertools;
use parking_lot::RwLock; use parking_lot::RwLock;
use pyo3::{ use pyo3::{
types::{PyDict, PyTuple}, types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python, PyAny, PyErr, PyObject, PyResult, Python,
}; };
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{NDArrayType, ProxyType}, model::*,
object::ndarray::{make_contiguous_strides, NDArray},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{ inkwell::{
module::Linkage, module::Linkage,
types::{BasicType, BasicTypeEnum}, types::BasicType,
values::BasicValueEnum, values::{BasicValue, BasicValueEnum},
AddressSpace, AddressSpace,
}, },
nac3parser::ast::{self, StrRef}, nac3parser::ast::{self, StrRef},
@ -1088,15 +1089,12 @@ impl InnerResolver {
let (ndarray_dtype, ndarray_ndims) = let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx); let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype));
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
{ {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global( ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(), Struct(NDArray).llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
) )
@ -1116,100 +1114,138 @@ impl InnerResolver {
} else { } else {
todo!("Unpacking literal of more than one element unimplemented") todo!("Unpacking literal of more than one element unimplemented")
}; };
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { let Ok(ndims) = u64::try_from(ndarray_ndims) else {
unreachable!("Expected u64 value for ndarray_ndims") unreachable!("Expected u64 value for ndarray_ndims")
}; };
// Obtain the shape of the ndarray // Obtain the shape of the ndarray
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize); assert_eq!(shape_tuple.len(), ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
// The Rust type inferencer cannot figure this out
let shape_values: Result<Vec<Instance<'ctx, Int<SizeT>>>, PyErr> = shape_tuple
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, elem)| { .map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( let value = self
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
) .map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})?
.unwrap();
let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap();
Ok(value)
}) })
.collect(); .collect();
let shape_values = shape_values?.unwrap(); let shape_values = shape_values?;
let shape_values = llvm_usize.const_array(
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), // Also use this opportunity to get the constant values of `shape_values` for calculating strides.
); let shape_u64s = shape_values
.iter()
.map(|dim| {
assert!(dim.value.is_const());
dim.value.get_zero_extended_constant().unwrap()
})
.collect_vec();
let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values);
// create a global for ndarray.shape and initialize it using the shape // create a global for ndarray.shape and initialize it using the shape
let shape_global = ctx.module.add_global( let shape_global = ctx.module.add_global(
llvm_usize.array_type(ndarray_ndims as u32), Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".shape"), &(id_str.clone() + ".shape"),
); );
shape_global.set_initializer(&shape_values); shape_global.set_initializer(&shape_values.value);
// Obtain the (flattened) elements of the ndarray // Obtain the (flattened) elements of the ndarray
let sz: usize = obj.getattr("size")?.extract()?; let sz: usize = obj.getattr("size")?.extract()?;
let data: Result<Option<Vec<_>>, _> = (0..sz) let data_values: Vec<Instance<'ctx, Any>> = (0..sz)
.map(|i| { .map(|i| {
obj.getattr("flat")?.get_item(i).and_then(|elem| { obj.getattr("flat")?.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { let value = self
super::CompileError::new_err(format!("Error getting element {i}: {e}")) .get_obj_value(py, elem, ctx, generator, ndarray_dtype)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting element {i}: {e}"
))
})?
.unwrap();
let value = dtype.check_value(generator, ctx.ctx, value).unwrap();
Ok(value)
}) })
}) })
}) .try_collect()?;
.collect(); let data = dtype.const_array(generator, ctx.ctx, &data_values);
let data = data?.unwrap().into_iter();
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
}
BasicTypeEnum::IntType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
}
BasicTypeEnum::PointerType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
}
BasicTypeEnum::StructType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
}
BasicTypeEnum::VectorType(_) => unreachable!(),
};
// create a global for ndarray.data and initialize it using the elements // create a global for ndarray.data and initialize it using the elements
//
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
// We will have to cast it to an `u8*` later.
let data_global = ctx.module.add_global( let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32), Array { len: AnyLen(sz as u32), item: dtype }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".data"), &(id_str.clone() + ".data"),
); );
data_global.set_initializer(&data); data_global.set_initializer(&data.value);
// Get the constant itemsize.
let itemsize = dtype.llvm_type(generator, ctx.ctx).size_of().unwrap();
let itemsize = itemsize.get_zero_extended_constant().unwrap();
// Create the strides needed for ndarray.strides
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
let strides = strides
.into_iter()
.map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride, false))
.collect_vec();
let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides);
// create a global for ndarray.strides and initialize it
let strides_global = ctx.module.add_global(
Array { len: AnyLen(ndims as u32), item: Int(Byte) }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()),
&(id_str.clone() + ".strides"),
);
strides_global.set_initializer(&strides.value);
// create a global for the ndarray object and initialize it // create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ // We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs.
llvm_usize.const_int(ndarray_ndims, false).into(),
shape_global
.as_pointer_value()
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
.into(),
data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]);
let ndarray = ctx.module.add_global( // NOTE: data_global is an array of dtype, we want a `u8*`.
ndarray_llvm_ty.as_underlying_type(), let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap();
let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value);
let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize, false);
let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
let ndarray_shape =
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap();
let ndarray_strides =
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap();
let ndarray = Struct(NDArray).const_struct(
generator,
ctx.ctx,
&[
ndarray_data.value.as_basic_value_enum(),
ndarray_itemsize.value.as_basic_value_enum(),
ndarray_ndims.value.as_basic_value_enum(),
ndarray_shape.value.as_basic_value_enum(),
ndarray_strides.value.as_basic_value_enum(),
],
);
let ndarray_global = ctx.module.add_global(
Struct(NDArray).llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
); );
ndarray.set_initializer(&value); ndarray_global.set_initializer(&ndarray.value);
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray_global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {

View File

@ -2,7 +2,6 @@
#include "irrt/int_types.hpp" #include "irrt/int_types.hpp"
#include "irrt/list.hpp" #include "irrt/list.hpp"
#include "irrt/math.hpp" #include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/range.hpp" #include "irrt/range.hpp"
#include "irrt/slice.hpp" #include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/basic.hpp"

View File

@ -1,151 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -5,12 +5,7 @@ use inkwell::{
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use super::{ use super::{CodeGenContext, CodeGenerator};
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
/// A LLVM type that is used to represent a non-primitive type in NAC3. /// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> { pub trait ProxyType<'ctx>: Into<Self::Base> {
@ -1140,626 +1135,3 @@ impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
value.as_base_value() value.as_base_value()
} }
} }
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
if llvm_ndarray_ty.count_fields() != 3 {
return Err(format!(
"Expected 3 fields in `NDArray`, got {}",
llvm_ndarray_ty.count_fields()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(_) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
Ok(())
}
/// Creates an instance of [`ListType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data: Pointer to an array containing the array data
let llvm_ndarray = ctx
.struct_type(
&[
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
],
false,
)
.ptr_type(AddressSpace::default());
NDArrayType::from_type(llvm_ndarray, llvm_usize)
}
/// Creates an [`NDArrayType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Underlying = StructType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.create_value(
generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(),
name,
)
}
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue { value, llvm_usize: self.llvm_usize, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
fn as_underlying_type(&self) -> Self::Underlying {
self.as_base_type().get_element_type().into_struct_type()
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
NDArrayType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
.create_value(ptr, name)
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
pub fn create_data(
&self,
ctx: &CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
) {
self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap());
}
/// Returns a proxy object to the field storing the data of this `NDArray`.
#[must_use]
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self)
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Underlying = StructValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
fn as_underlying_value(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Underlying {
ctx.builder
.build_load(self.as_base_value(), name.unwrap_or_default())
.map(BasicValueEnum::into_struct_value)
.unwrap()
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.data().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[*idx],
name.unwrap_or_default(),
)
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let data_sz = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx
.builder
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
.unwrap();
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
let indices_len = indices.size(ctx, generator);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe {
(
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx
.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt =
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) }
}
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}

View File

@ -3,7 +3,7 @@ use inkwell::{
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::{BasicTypeEnum, IntType}, types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
@ -12,18 +12,13 @@ use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use super::{ use super::{
classes::{ classes::{ArrayLikeValue, ListValue},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable, macros::codegen_unreachable,
model::{function::FnCall, *}, model::{function::FnCall, *},
object::{ object::{
list::List, list::List,
ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray},
}, },
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
@ -589,373 +584,6 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
.unwrap() .unwrap()
} }
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". // When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". // When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use] #[must_use]

File diff suppressed because it is too large Load Diff

View File

@ -652,3 +652,18 @@ impl<'ctx> NDArrayOut<'ctx> {
} }
} }
} }
/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust.
///
/// This function is used generating strides for globally defined contiguous ndarrays.
#[must_use]
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
let mut strides = Vec::with_capacity(ndims as usize);
let mut stride_product = 1u64;
for i in 0..ndims {
let axis = ndims - i - 1;
strides[axis as usize] = stride_product * itemsize;
stride_product *= shape[axis as usize];
}
strides
}

View File

@ -16,7 +16,7 @@ use nac3parser::{
use parking_lot::RwLock; use parking_lot::RwLock;
use super::{ use super::{
classes::{ListType, NDArrayType, ProxyType, RangeType}, classes::{ListType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore, concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry, DefaultCodeGenerator, WithCall, WorkerRegistry,
@ -462,15 +462,3 @@ fn test_classes_range_type_new() {
let llvm_range = RangeType::new(&ctx); let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok()); assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
} }
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}