Compare commits

...

19 Commits

Author SHA1 Message Date
2cfb8baae5 [artiq] Reimplement get_obj_value for strided ndarray
Based on 7ef93472: artiq: reimplement get_obj_value to use ndarray with
strides
2024-11-29 17:59:49 +08:00
b40e9bca28 [artiq] codegen: Reimplement polymorphic_print for strided ndarray
Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
2024-11-29 17:27:57 +08:00
bbc68b8b1a [core] codegen: implement ndarray iterator NDIter
Based on 50f960ab: core/ndstrides: implement ndarray iterator NDIter

A necessary utility to iterate through all elements in a possibly
strided ndarray.
2024-11-29 17:27:14 +08:00
73b0f2bcc9 [artiq] codegen: Reimplement polymorphic_print for strided ndarray
Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
2024-11-29 17:27:13 +08:00
e965a7c7ce [core] codegen: Implement ContiguousNDArray
Fixes compatibility with linalg algorithms. matrix_power is missing due
to the need for indexing support.
2024-11-29 17:25:50 +08:00
57da0f67d1 [core] codegen: Implement NDArray functions from a0a1f35b 2024-11-29 17:25:50 +08:00
624e943cd6 [core] codegen/irrt: Add IRRT functions for strided-ndarray 2024-11-29 17:25:50 +08:00
a99ae4828a [core] Add itemsize and strides to NDArray struct
Temporarily disable linalg ndarray tests as they are not ported to work
with strided-ndarray.
2024-11-29 17:25:50 +08:00
acfa81ff60 [core] codegen: Add helper functions for create+call functions
Replacement for various FnCall methods from legacy ndstrides
implementation.
2024-11-29 17:25:50 +08:00
35ef3c3f27 [core] codegen: Add call_memcpy_generic_array
Replacement for Instance<Ptr>::copy_from from legacy ndstrides
implementation.
2024-11-29 17:25:50 +08:00
cb6faeabb6 [core] Add type_aligned_alloca 2024-11-29 17:25:49 +08:00
47fba32926 [core] codegen/types: Add docs for NDArrayType::fields 2024-11-29 17:19:46 +08:00
4cfa848399 [core] Expose irrt::ndarray 2024-11-29 17:19:46 +08:00
355c051886 [core] codegen/ndarray: Cleanup
- Remove redundant size param
- Add *_field functions
2024-11-29 17:19:46 +08:00
363e1a1f84 [core] Move alloca and map_value of ProxyType to implementations
These functions may not be invokable with the same set of parameters, as
some classes have associated state.
2024-11-29 17:19:46 +08:00
a3c1d469fc [core] codegen/types: Rename StructField::set_from_value 2024-11-29 17:19:46 +08:00
cf8d732532 [standalone] linalg: Fix function name in error message 2024-11-29 17:19:46 +08:00
814dda55d7 [meta] Remove all mentions of build_int_cast
build_int_cast performs signed extension or truncation depending on the
source and target int lengths. This is usually not what we want - We
want zero-extension instead.

Replace all instances of build_int_cast with
build_int_z_extend_or_bit_cast to fix this issue.
2024-11-29 17:19:43 +08:00
10894085bb [core] codegen: Move ndarray type/value as a separate module 2024-11-29 15:44:16 +08:00
31 changed files with 3690 additions and 1061 deletions

View File

@ -12,16 +12,17 @@ use pyo3::{
PyObject, PyResult, Python, PyObject, PyResult, Python,
}; };
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
types::{NDArrayType, ProxyType}, type_aligned_alloca,
types::NDArrayType,
values::{ values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue,
RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -34,12 +35,14 @@ use nac3core::{
}, },
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}, nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, toplevel::{
helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys,
DefinitionId, GenCall,
},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
/// The parallelism mode within a block. /// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
enum ParallelMode { enum ParallelMode {
@ -458,52 +461,49 @@ fn format_rpc_arg<'ctx>(
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); let dtype = ctx.get_llvm_type(generator, elem_ty);
let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None); let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims))
.map_value(arg.into_pointer_value(), None);
let llvm_usize_sizeof = ctx let ndims = llvm_usize.const_int(ndims, false);
.builder
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz = // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); // the reader.
// Turning it into a ContiguousNDArray to get a `data` that is contiguous.
let carray = ndarray.make_contiguous_ndarray(generator, ctx);
let buffer_size = let sizeof_usize = llvm_usize.size_of();
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); let sizeof_usize =
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let sizeof_pdata = dtype.ptr_type(AddressSpace::default()).size_of();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let sizeof_pdata =
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_pdata, llvm_usize, "").unwrap();
call_memcpy_generic( let sizeof_buf_shape = ctx.builder.build_int_mul(sizeof_usize, ndims, "").unwrap();
ctx, let sizeof_buf = ctx.builder.build_int_add(sizeof_buf_shape, sizeof_pdata, "").unwrap();
buffer.base_ptr(ctx, generator),
llvm_arg.ptr_to_data(ctx),
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin = // buf = { data: void*, shape: [size_t; ndims]; }
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let buf = ctx.builder.build_array_alloca(llvm_i8, sizeof_buf, "rpc.arg").unwrap();
call_memcpy_generic( let buf = ArraySliceValue::from_ptr_val(buf, sizeof_buf, Some("rpc.arg"));
ctx, let buf_data = buf.base_ptr(ctx, generator);
pbuffer_dims_begin, let buf_shape =
llvm_arg.shape().base_ptr(ctx, generator), unsafe { buf.ptr_offset_unchecked(ctx, generator, &sizeof_pdata, None) };
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator) // Write to `buf->data`
let carray_data = carray.load_data(ctx);
let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "").unwrap();
call_memcpy(ctx, buf_data, carray_data, sizeof_pdata, llvm_i1.const_zero());
// Write to `buf->shape`
let carray_shape = ndarray.shape().base_ptr(ctx, generator);
let carray_shape_i8 =
ctx.builder.build_pointer_cast(carray_shape, llvm_pi8, "").unwrap();
call_memcpy(ctx, buf_shape, carray_shape_i8, sizeof_buf_shape, llvm_i1.const_zero());
buf.base_ptr(ctx, generator)
} }
_ => { _ => {
@ -544,6 +544,8 @@ fn format_rpc_ret<'ctx>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
@ -564,8 +566,7 @@ fn format_rpc_ret<'ctx>(
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type(); let num_0 = llvm_usize.const_zero();
let llvm_usize = generator.get_size_type(ctx.ctx);
// Round `val` up to its modulo `power_of_two` // Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>, let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
@ -591,79 +592,49 @@ fn format_rpc_ret<'ctx>(
.unwrap() .unwrap()
}; };
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray // Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off. // A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims))
.construct_uninitialized(generator, ctx, llvm_usize.const_int(ndims, false), None);
// Setup ndims // NOTE: Current content of `ndarray`:
let ndims = // - * `data` - **NOT YET** allocated.
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { // - * `itemsize` - initialized to be size_of(dtype).
assert_eq!(values.len(), 1); // - * `ndims` - initialized.
// - * `shape` - allocated; has uninitialized values.
// - * `strides` - allocated; has uninitialized values.
u64::try_from(values[0].clone()).unwrap() let itemsize = ndarray.load_itemsize(ctx); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`.
} else {
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment // (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims = let sizeof_usize = llvm_usize.size_of();
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_usize =
ctx.builder.build_int_truncate_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
let sizeof_ptr = llvm_i8.ptr_type(AddressSpace::default()).size_of();
let sizeof_ptr =
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_ptr, llvm_usize, "").unwrap();
let sizeof_shape =
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), sizeof_usize, "").unwrap();
// Size of the buffer for the initial `rpc_recv()`.
let unaligned_buffer_size = let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap(); ctx.builder.build_int_add(sizeof_ptr, sizeof_shape, "").unwrap();
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
let stackptr = call_stacksave(ctx, None); let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment let buffer = type_aligned_alloca(
let buffer = ctx generator,
.builder ctx,
.build_array_alloca(
llvm_i8_8, llvm_i8_8,
ctx.builder unaligned_buffer_size,
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "") Some("rpc.buffer"),
.unwrap(), );
"rpc.buffer", let buffer = ArraySliceValue::from_ptr_val(buffer, unaligned_buffer_size, None);
)
.unwrap();
let buffer = ctx
.builder
.build_bit_cast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape] // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
// //
@ -671,7 +642,7 @@ fn format_rpc_ret<'ctx>(
let ndarray_nbytes = ctx let ndarray_nbytes = ctx
.build_call_or_invoke( .build_call_or_invoke(
rpc_recv, rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims]. &[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]
"rpc.size.next", "rpc.size.next",
) )
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
@ -679,16 +650,14 @@ fn format_rpc_ret<'ctx>(
// debug_assert(ndarray_nbytes > 0) // debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let cmp = ctx
.builder
.build_int_compare(IntPredicate::UGT, ndarray_nbytes, num_0, "")
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder cmp,
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next", "Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None], [None, None, None],
@ -697,49 +666,50 @@ fn format_rpc_ret<'ctx>(
} }
// Copy shape from the buffer to `ndarray.shape`. // Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims = // We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let pbuffer_shape =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &sizeof_ptr, None) };
let pbuffer_shape =
ctx.builder.build_pointer_cast(pbuffer_shape, llvm_pusize, "").unwrap();
// Copy shape from buffer to `ndarray.shape`
ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
call_memcpy_generic(
ctx,
ndarray.shape().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer // Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`. // Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation // `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate) // (for ndarray.create_data() to know how many elements to allocate)
let num_elements = unsafe { ndarray.create_data(generator, ctx) }; // NOTE: the strides of `ndarray` has also been set to contiguous in `create_data`.
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data = let num_elements = ndarray.size(generator, ctx);
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
let expected_ndarray_nbytes =
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
let cmp = ctx
.builder
.build_int_compare(
IntPredicate::UGE,
expected_ndarray_nbytes,
ndarray_nbytes,
"",
)
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::UGE, cmp,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None], [Some(expected_ndarray_nbytes), Some(ndarray_nbytes), None],
ctx.current_loc, ctx.current_loc,
); );
} }
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
let ndarray_data = ndarray.data().base_ptr(ctx, generator); let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb` // NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
@ -748,7 +718,7 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(head_bb); ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]); phi.add_incoming(&[(&ndarray_data, prehead_bb)]);
let alloc_size = ctx let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
@ -763,12 +733,13 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(alloc_bb); ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T) // Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof); let alloc_size = round_up(ctx, alloc_size, itemsize);
// TODO(Derppening): Candidate for refactor into type_aligned_alloca
let alloc_ptr = ctx let alloc_ptr = ctx
.builder .builder
.build_array_alloca( .build_array_alloca(
llvm_elem_ty, dtype_llvm,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(), ctx.builder.build_int_unsigned_div(alloc_size, itemsize, "").unwrap(),
"rpc.alloc", "rpc.alloc",
) )
.unwrap(); .unwrap();
@ -1367,62 +1338,50 @@ fn polymorphic_print<'ctx>(
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
fmt.push_str("array(["); fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_pointer_value( let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
value.into_pointer_value(), let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
llvm_elem_ty, .map_value(value.into_pointer_value(), None);
llvm_usize,
None,
);
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
gen_for_callback_incrementing( let num_0 = llvm_usize.const_zero();
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
polymorphic_print( // Print `ndarray` as a flat list delimited by interspersed with ", \0"
ctx, ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
generator, let i = hdl.get_index(ctx);
&[(elem_ty, elem.into())], let scalar = hdl.get_scalar(ctx);
"",
None,
true,
as_rtio,
)?;
// if (i != 0) puts(", ");
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|_, ctx| { |_, ctx| {
Ok(ctx let not_first = ctx
.builder .builder
.build_int_compare(IntPredicate::ULT, i, last, "") .build_int_compare(IntPredicate::NE, i, num_0, "")
.unwrap()) .unwrap();
Ok(not_first)
}, },
|generator, ctx| { |generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default()); printf(ctx, generator, ", \0".into(), Vec::default());
Ok(()) Ok(())
}, },
|_, _| Ok(()), |_, _| Ok(()),
)?; )?;
Ok(()) // Print element
}, polymorphic_print(
llvm_usize.const_int(1, false), ctx,
generator,
&[(dtype, scalar.into())],
"",
None,
true,
as_rtio,
)?; )?;
Ok(())
})?;
fmt.push_str(")]"); fmt.push_str(")]");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);

View File

@ -10,12 +10,14 @@ use itertools::Itertools;
use parking_lot::RwLock; use parking_lot::RwLock;
use pyo3::{ use pyo3::{
types::{PyDict, PyTuple}, types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python, PyAny, PyErr, PyObject, PyResult, Python,
}; };
use super::PrimitivePythonId;
use nac3core::{ use nac3core::{
codegen::{ codegen::{
types::{NDArrayType, ProxyType}, types::{NDArrayType, ProxyType},
values::make_contiguous_strides,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{ inkwell::{
@ -37,8 +39,6 @@ use nac3core::{
}, },
}; };
use super::PrimitivePythonId;
pub enum PrimitiveValue { pub enum PrimitiveValue {
I32(i32), I32(i32),
I64(i64), I64(i64),
@ -1088,15 +1088,17 @@ impl InnerResolver {
let (ndarray_dtype, ndarray_ndims) = let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); let dtype = llvm_ndarray.element_type();
{ {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global( ctx.module.add_global(
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(), llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
) )
@ -1116,30 +1118,43 @@ impl InnerResolver {
} else { } else {
todo!("Unpacking literal of more than one element unimplemented") todo!("Unpacking literal of more than one element unimplemented")
}; };
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { let Ok(ndims) = u64::try_from(ndarray_ndims) else {
unreachable!("Expected u64 value for ndarray_ndims") unreachable!("Expected u64 value for ndarray_ndims")
}; };
// Obtain the shape of the ndarray // Obtain the shape of the ndarray
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize); assert_eq!(shape_tuple.len(), ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
// The Rust type inferencer cannot figure this out
let shape_values = shape_tuple
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, elem)| { .map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( let value = self
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
) .map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})?
.unwrap();
let value = value.into_int_value();
Ok(value)
}) })
.collect(); .collect::<Result<Vec<_>, PyErr>>()?;
let shape_values = shape_values?.unwrap();
let shape_values = llvm_usize.const_array( // Also use this opportunity to get the constant values of `shape_values` for calculating strides.
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), let shape_u64s = shape_values
); .iter()
.map(|dim| {
assert!(dim.is_const());
dim.get_zero_extended_constant().unwrap()
})
.collect_vec();
let shape_values = llvm_usize.const_array(&shape_values);
// create a global for ndarray.shape and initialize it using the shape // create a global for ndarray.shape and initialize it using the shape
let shape_global = ctx.module.add_global( let shape_global = ctx.module.add_global(
llvm_usize.array_type(ndarray_ndims as u32), llvm_usize.array_type(ndims as u32),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".shape"), &(id_str.clone() + ".shape"),
); );
@ -1147,17 +1162,25 @@ impl InnerResolver {
// Obtain the (flattened) elements of the ndarray // Obtain the (flattened) elements of the ndarray
let sz: usize = obj.getattr("size")?.extract()?; let sz: usize = obj.getattr("size")?.extract()?;
let data: Result<Option<Vec<_>>, _> = (0..sz) let data: Vec<_> = (0..sz)
.map(|i| { .map(|i| {
obj.getattr("flat")?.get_item(i).and_then(|elem| { obj.getattr("flat")?.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { let value = self
super::CompileError::new_err(format!("Error getting element {i}: {e}")) .get_obj_value(py, elem, ctx, generator, ndarray_dtype)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting element {i}: {e}"
))
})?
.unwrap();
assert_eq!(value.get_type(), dtype);
Ok(value)
}) })
}) })
}) .try_collect()?;
.collect(); let data = data.into_iter();
let data = data?.unwrap().into_iter(); let data = match dtype {
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => { BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
} }
@ -1182,38 +1205,68 @@ impl InnerResolver {
}; };
// create a global for ndarray.data and initialize it using the elements // create a global for ndarray.data and initialize it using the elements
//
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
// We will have to cast it to an `u8*` later.
let data_global = ctx.module.add_global( let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32), dtype.array_type(sz as u32),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".data"), &(id_str.clone() + ".data"),
); );
data_global.set_initializer(&data); data_global.set_initializer(&data);
// Get the constant itemsize.
let itemsize = dtype.size_of().unwrap();
let itemsize = itemsize.get_zero_extended_constant().unwrap();
// Create the strides needed for ndarray.strides
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
let strides =
strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec();
let strides = llvm_usize.const_array(&strides);
// create a global for ndarray.strides and initialize it
let strides_global = ctx.module.add_global(
llvm_i8.array_type(ndims as u32),
Some(AddressSpace::default()),
&format!("${id_str}.strides"),
);
strides_global.set_initializer(&strides);
// create a global for the ndarray object and initialize it // create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty
// NOTE: data_global is an array of dtype, we want a `u8*`.
let ndarray_data = data_global.as_pointer_value();
let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
let ndarray_itemsize = llvm_usize.const_int(itemsize, false);
let ndarray_ndims = llvm_usize.const_int(ndims, false);
let ndarray_shape = shape_global.as_pointer_value();
let ndarray_strides = strides_global.as_pointer_value();
let ndarray = llvm_ndarray
.as_base_type() .as_base_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type()
.const_named_struct(&[ .const_named_struct(&[
llvm_usize.const_int(ndarray_ndims, false).into(), ndarray_itemsize.into(),
shape_global ndarray_ndims.into(),
.as_pointer_value() ndarray_shape.into(),
.const_cast(llvm_usize.ptr_type(AddressSpace::default())) ndarray_strides.into(),
.into(), ndarray_data.into(),
data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]); ]);
let ndarray = ctx.module.add_global( let ndarray_global = ctx.module.add_global(
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(), llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
); );
ndarray.set_initializer(&value); ndarray_global.set_initializer(&ndarray);
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray_global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {

View File

@ -3,3 +3,6 @@
#include "irrt/math.hpp" #include "irrt/math.hpp"
#include "irrt/ndarray.hpp" #include "irrt/ndarray.hpp"
#include "irrt/slice.hpp" #include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"

View File

@ -2,6 +2,8 @@
#include "irrt/int_types.hpp" #include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace { namespace {
template<typename SizeT> template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {

View File

@ -0,0 +1,342 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace basic {
/**
 * @brief Check every dimension of `shape` and raise an `EXN_VALUE_ERROR`
 * exception (via `raise_exception`) for each one that is below zero.
 *
 * The exception message reports the offending axis as `{0}` and its
 * dimension as `{1}`.
 *
 * @param ndims Number of entries in `shape`
 * @param shape The shape to validate
 */
template<typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
    SizeT axis = 0;
    while (axis < ndims) {
        const SizeT dim = shape[axis];
        if (dim < 0) {
            raise_exception(SizeT, EXN_VALUE_ERROR,
                            "negative dimensions are not allowed; axis {0} "
                            "has dimension {1}",
                            axis, dim, NO_PARAM);
        }
        axis++;
    }
}
/**
 * @brief Verify that an output shape matches the shape of the ndarray it is written to.
 *
 * Two checks are made, in this order:
 *  1. the number of dimensions must be equal;
 *  2. every axis (iterating over `ndarray_ndims` axes) must have equal dimensions.
 *
 * Each failed check raises an `EXN_VALUE_ERROR` exception through `raise_exception`.
 *
 * @param ndarray_ndims Number of dimensions of the destination ndarray
 * @param ndarray_shape Shape of the destination ndarray
 * @param output_ndims Number of dimensions of the output
 * @param output_shape Shape of the output
 */
template<typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims,
                              const SizeT* ndarray_shape,
                              SizeT output_ndims,
                              const SizeT* output_shape) {
    const bool ndims_match = ndarray_ndims == output_ndims;
    if (!ndims_match) {
        // NumPy has no equivalent error message for this case.
        raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
                        output_ndims, ndarray_ndims, NO_PARAM);
    }

    SizeT axis = 0;
    while (axis < ndarray_ndims) {
        const SizeT dst_dim = ndarray_shape[axis];
        const SizeT out_dim = output_shape[axis];
        if (dst_dim != out_dim) {
            // NumPy has no equivalent error message for this case.
            raise_exception(SizeT, EXN_VALUE_ERROR,
                            "Mismatched dimensions on axis {0}, output has "
                            "dimension {1}, but destination ndarray has dimension {2}.",
                            axis, out_dim, dst_dim);
        }
        axis++;
    }
}
/**
 * @brief Compute how many elements an ndarray of the given shape holds.
 *
 * The result is the product of all dimensions; an empty shape (`ndims == 0`) yields 1.
 *
 * @param ndims Number of entries in `shape`
 * @param shape The shape of the ndarray
 */
template<typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
    SizeT product = 1;
    SizeT axis = 0;
    while (axis < ndims) {
        product *= shape[axis];
        ++axis;
    }
    return product;
}
/**
 * @brief Convert a flat (0-based) element number into per-axis indices, using only the shape.
 *
 * The last axis varies fastest: `nth` is decomposed by repeatedly taking the remainder
 * and quotient by each dimension, from the last axis to the first.
 *
 * @param ndims Number of entries in `shape` and `indices`
 * @param shape The shape of the ndarray
 * @param indices Output buffer receiving the indices for `shape`
 * @param nth The flat index of the element of interest
 */
template<typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
    for (SizeT remaining = ndims; remaining > 0; remaining--) {
        const SizeT axis = remaining - 1;
        const SizeT dim = shape[axis];
        indices[axis] = nth % dim;
        nth /= dim;
    }
}
/**
 * @brief Number of elements held by `ndarray` (the product of its shape).
 *
 * This function corresponds to `<an_ndarray>.size`
 */
template<typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
    const SizeT ndims = ndarray->ndims;
    const SizeT* shape = ndarray->shape;
    return calc_size_from_shape(ndims, shape);
}
/**
 * @brief Number of bytes occupied by the elements of `ndarray` (`size * itemsize`).
 *
 * This function corresponds to `<an_ndarray>.nbytes`.
 */
template<typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
    const SizeT element_count = size(ndarray);
    return element_count * ndarray->itemsize;
}
/**
 * @brief Return the `len()` of an ndarray, i.e. the dimension of its first axis.
 *
 * This function corresponds to `<an_ndarray>.__len__`.
 *
 * Raises `EXN_TYPE_ERROR` when `ndarray` has zero dimensions (an unsized object).
 */
template<typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
    if (ndarray->ndims == 0) {
        // numpy prohibits `__len__` on unsized objects
        raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
        __builtin_unreachable();
    }

    return ndarray->shape[0];
}
/**
 * @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
 *
 * References:
 * - tinynumpy's implementation:
 *   https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
 * - ndarray's flags["C_CONTIGUOUS"]:
 *   https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
 * - ndarray's rules for C-contiguity:
 *   https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
 *
 * This function applies the traditional rule only:
 *
 *     strides[-1] == itemsize
 *     strides[i] == shape[i + 1] * strides[i + 1]
 *
 * A 0-dimensional ndarray is always reported as contiguous. The relaxations NumPy
 * describes for arrays with zero or one item are not applied here; such arrays are
 * reported as contiguous only when their strides satisfy the rule above.
 */
template<typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
    const SizeT ndims = ndarray->ndims;
    if (ndims == 0) {
        return true;
    }

    const SizeT* shape = ndarray->shape;
    const SizeT* strides = ndarray->strides;

    // The innermost axis must step by exactly one element.
    if (strides[ndims - 1] != ndarray->itemsize) {
        return false;
    }

    // Every outer axis must step over one full extent of the axis directly inside it.
    for (SizeT inner = ndims - 1; inner > 0; inner--) {
        const SizeT outer = inner - 1;
        if (strides[outer] != shape[inner] * strides[inner]) {
            return false;
        }
    }

    return true;
}
/**
 * @brief Locate the element of `ndarray` addressed by one index per axis.
 *
 * The byte offset from `data` is the sum of `indices[axis] * strides[axis]` over all axes.
 *
 * This function does no bound check.
 */
template<typename SizeT>
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
    uint8_t* cursor = static_cast<uint8_t*>(ndarray->data);
    for (SizeT axis = 0; axis < ndarray->ndims; axis++) {
        cursor += indices[axis] * ndarray->strides[axis];
    }
    return cursor;
}
/**
 * @brief Locate the nth (0-based) element of `ndarray` when viewed as flattened.
 *
 * `nth` is decomposed into one index per axis (last axis varying fastest) and each index
 * is scaled by that axis' stride to form the byte offset from `data`.
 *
 * This function does no bound check.
 */
template<typename SizeT>
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
    uint8_t* cursor = static_cast<uint8_t*>(ndarray->data);
    for (SizeT remaining = ndarray->ndims; remaining > 0; remaining--) {
        const SizeT axis = remaining - 1;
        const SizeT dim = ndarray->shape[axis];
        const SizeT index = nth % dim;
        cursor += ndarray->strides[axis] * index;
        nth /= dim;
    }
    return cursor;
}
/**
 * @brief Overwrite `ndarray->strides` so that they describe a contiguous layout for
 * the ndarray's current `shape` and `itemsize`.
 *
 * The last axis gets a stride of `itemsize`; each preceding axis gets the stride of the
 * axis after it multiplied by that axis' dimension.
 *
 * You might want to read https://ajcr.net/stride-guide-part-1/.
 */
template<typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
    // Number of elements skipped when the current axis' index advances by one.
    SizeT elements_per_step = 1;
    for (SizeT remaining = ndarray->ndims; remaining > 0; remaining--) {
        const SizeT axis = remaining - 1;
        ndarray->strides[axis] = elements_per_step * ndarray->itemsize;
        elements_per_step *= ndarray->shape[axis];
    }
}
/**
 * @brief Write one element of `ndarray` by copying `itemsize` bytes into it.
 *
 * @param ndarray The ndarray whose `itemsize` determines how many bytes are copied.
 * @param pelement Pointer to the element in `ndarray` to be set.
 * @param pvalue Pointer to the value `pelement` will be set to.
 */
template<typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
    const SizeT element_bytes = ndarray->itemsize;
    __builtin_memcpy(pelement, pvalue, element_bytes);
}
/**
 * @brief Copy data from one ndarray to another of the exact same size and itemsize.
 *
 * Both ndarrays are traversed in their flattened views: the nth element of `src_ndarray`
 * is copied to the nth element of `dst_ndarray`, so the two may have different strides.
 *
 * In debug builds, equality of the two `itemsize`s is asserted. The number of elements
 * copied is taken from `src_ndarray`.
 *
 * @param src_ndarray The ndarray to read elements from.
 * @param dst_ndarray The ndarray to write elements to.
 */
template<typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
    // TODO: Make this faster with memcpy when we see a contiguous segment.
    // TODO: Handle overlapping.

    debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);

    // The element count only depends on the source shape, so compute it once rather than
    // re-multiplying the whole shape in the loop condition for every element.
    const SizeT num_elements = size(src_ndarray);
    for (SizeT i = 0; i < num_elements; i++) {
        auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
        auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
        ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
    }
}
} // namespace basic
} // namespace ndarray
} // namespace
// C-linkage entry points for the `ndarray::basic` templates above.
// Each function comes in two flavours: the unsuffixed one instantiates the template with
// `int32_t` as `SizeT`, and the one suffixed with `64` instantiates it with `int64_t`.
extern "C" {
using namespace ndarray::basic;
// Forwards to `assert_shape_no_negative`.
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
// Forwards to `assert_output_shape_same`.
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
const int64_t* ndarray_shape,
int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
// Forwards to `size`. Note that the result is returned through an unsigned type.
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
// Forwards to `nbytes`. Note that the result is returned through an unsigned type.
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
// Forwards to `len`.
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
return len(ndarray);
}
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
return len(ndarray);
}
// Forwards to `is_c_contiguous`.
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
// Forwards to `get_nth_pelement`.
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
// Forwards to `get_pelement_by_indices`.
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
// Forwards to `set_strides_by_shape`.
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
// Forwards to `copy_data`.
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,51 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
/**
* @brief The NDArray object
*
* Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
*
* Note that this implementation is based on `PyArrayInterface` rather than `PyArrayObject`. The
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
* `data`. There are also minor differences in the struct layout.
*
* NOTE(review): presumably the order of the fields below must stay in sync with the struct
* layout used by the code generator - confirm before reordering or inserting fields.
*/
template<typename SizeT>
struct NDArray {
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this ndarray, i.e. the length of `shape` and `strides`.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values or contain 0.
*/
SizeT* strides;
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
};
} // namespace

View File

@ -0,0 +1,146 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
/**
 * @brief Helper struct to enumerate through an ndarray *efficiently*.
 *
 * Example usage (in pseudo-code):
 * ```
 * // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`,
 * // and `indices` is a caller-allocated buffer with `my_ndarray.ndims` entries.
 * NDIter nditer;
 * nditer.initialize_by_ndarray(my_ndarray, indices);
 * while (nditer.has_element()) {
 *     // This body is run 6 (= my_ndarray.size) times.
 *
 *     // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
 *     print(nditer.indices);
 *
 *     // 0 -> 1 -> 2 -> 3 -> 4 -> 5
 *     print(nditer.nth);
 *
 *     // <1st element> -> <2nd element> -> ... -> <6th element> -> end
 *     print(*((double *) nditer.element))
 *
 *     nditer.next(); // Go to next element.
 * }
 * ```
 *
 * Interesting cases:
 * - If `my_ndarray.ndims` == 0, there is one iteration.
 * - If `my_ndarray.shape` contains zeroes, there are no iterations.
 */
template<typename SizeT>
struct NDIter {
    // Information about the ndarray being iterated over.
    SizeT ndims;
    SizeT* shape;
    SizeT* strides;

    /**
     * @brief The current indices.
     *
     * Must be allocated by the caller.
     */
    SizeT* indices;

    /**
     * @brief The nth (0-based) index of the current indices.
     *
     * Initially this is 0.
     */
    SizeT nth;

    /**
     * @brief Pointer to the current element.
     *
     * Initially this points to first element of the ndarray.
     */
    void* element;

    /**
     * @brief Cache for the product of shape.
     *
     * Could be 0 if `shape` has 0s in it.
     */
    SizeT size;

    /**
     * @brief Set up the iterator from raw ndarray information.
     *
     * @param ndims Number of entries in `shape`, `strides` and `indices`.
     * @param shape Shape of the ndarray being iterated over.
     * @param strides Strides of the ndarray being iterated over.
     * @param element Pointer to the first element of the ndarray.
     * @param indices Caller-allocated buffer; its first `ndims` entries are reset to 0.
     */
    void initialize(SizeT ndims, SizeT* shape, SizeT* strides, void* element, SizeT* indices) {
        this->ndims = ndims;
        this->shape = shape;
        this->strides = strides;

        this->indices = indices;
        this->element = element;

        // Cache the total number of elements (the product of `shape`).
        this->size = 1;
        for (SizeT axis = 0; axis < ndims; axis++) {
            this->size *= shape[axis];
        }

        // Iteration begins at the all-zero index, which is element number 0.
        SizeT axis = 0;
        while (axis < ndims) {
            indices[axis] = 0;
            axis++;
        }
        this->nth = 0;
    }

    /**
     * @brief Set up the iterator from an `NDArray`.
     *
     * `ndarray->data` points to the first element, which is also where `element` starts.
     */
    void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
        this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
    }

    // Is the current iteration valid?
    // If true, then `element`, `indices` and `nth` contain details about the current element.
    bool has_element() { return size > nth; }

    // Go to the next element.
    void next() {
        // Advance like an odometer, starting from the last axis.
        for (SizeT remaining = ndims; remaining > 0; remaining--) {
            const SizeT axis = remaining - 1;
            uint8_t* const cursor = static_cast<uint8_t*>(element);

            indices[axis]++;
            if (indices[axis] < shape[axis]) {
                // Still inside this axis: take a single step along it and stop.
                element = cursor + strides[axis];
                break;
            }

            // This axis overflowed: rewind it to index 0 and carry into the next axis.
            // TODO: There is something called backstrides to speedup iteration.
            // See https://ajcr.net/stride-guide-part-1/, and
            // https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
            indices[axis] = 0;
            element = cursor - strides[axis] * (shape[axis] - 1);
        }
        nth++;
    }
};
} // namespace
// C-linkage entry points for `NDIter`.
// The unsuffixed functions operate on `NDIter<int32_t>`, the `64`-suffixed ones on `NDIter<int64_t>`.
extern "C" {
// Forwards to `NDIter::initialize_by_ndarray`; `indices` is the buffer the iterator will write into.
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
// Forwards to `NDIter::has_element`.
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
return iter->has_element();
}
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
return iter->has_element();
}
// Forwards to `NDIter::next`.
void __nac3_nditer_next(NDIter<int32_t>* iter) {
iter->next();
}
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
iter->next();
}
}

View File

@ -14,6 +14,7 @@ use super::{
numpy, numpy,
numpy::ndarray_elementwise_unaryop_impl, numpy::ndarray_elementwise_unaryop_impl,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::NDArrayType,
values::{ values::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
@ -22,7 +23,7 @@ use super::{
}; };
use crate::{ use crate::{
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, PrimDef}, helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys, numpy::unpack_ndarray_var_tys,
}, },
typecheck::typedef::{Type, TypeEnum}, typecheck::typedef::{Type, TypeEnum},
@ -68,12 +69,14 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_pointer_value( let arg = NDArrayValue::from_pointer_value(
arg.into_pointer_value(), arg.into_pointer_value(),
ctx.get_llvm_type(generator, elem_ty), ctx.get_llvm_type(generator, elem_ty),
Some(ndims),
llvm_usize, llvm_usize,
None, None,
); );
@ -145,7 +148,8 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -153,7 +157,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int32, ctx.primitives.int32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -208,7 +212,8 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -216,7 +221,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int64, ctx.primitives.int64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -287,7 +292,8 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -295,7 +301,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint32, ctx.primitives.uint32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -355,7 +361,8 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -363,7 +370,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint64, ctx.primitives.uint64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -422,7 +429,8 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -430,7 +438,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
)?; )?;
@ -469,7 +477,8 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -477,7 +486,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -510,7 +519,8 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -518,7 +528,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
)?; )?;
@ -576,7 +586,8 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -584,7 +595,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.bool, ctx.primitives.bool,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| { |generator, ctx, val| {
let elem = call_bool(generator, ctx, (elem_ty, val))?; let elem = call_bool(generator, ctx, (elem_ty, val))?;
@ -631,7 +642,8 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -639,7 +651,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -682,7 +694,8 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -690,7 +703,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -918,11 +931,14 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n =
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None);
let n_sz =
irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx let n_sz_eqz = ctx
.builder .builder
@ -1126,7 +1142,8 @@ where
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let (arg_elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_arg_elem_ty = ctx.get_llvm_type(generator, arg_elem_ty); let llvm_arg_elem_ty = ctx.get_llvm_type(generator, arg_elem_ty);
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
@ -1135,7 +1152,13 @@ where
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(
x,
llvm_arg_elem_ty,
Some(ndims),
llvm_usize,
None,
),
|generator, ctx, elem_val| { |generator, ctx, elem_val| {
helper_call_numpy_unary_elementwise( helper_call_numpy_unary_elementwise(
generator, generator,
@ -1960,282 +1983,290 @@ fn build_output_struct<'ctx>(
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky"; const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_cholesky(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
} }
/// Invokes the `np_linalg_qr` linalg function /// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr"; const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
unimplemented!("{FN_NAME} operates on float type NdArrays only"); let ndims = extract_ndims(&ctx.unifier, ndims);
}; let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); if !x1.get_type().element_type().is_float_type() {
let dim0 = unsafe { unsupported_type(ctx, FN_NAME, &[x1_ty]);
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { q.create_data(generator, ctx) };
let r = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
unsafe { r.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let q_c = q.make_contiguous_ndarray(generator, ctx);
let r_c = r.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_qr(
ctx,
x1_c.as_base_value().into(),
q_c.as_base_value().into(),
r_c.as_base_value().into(),
None,
);
let q = q.as_base_value().into();
let r = r.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![q, r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} }
/// Invokes the `np_linalg_svd` linalg function /// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd"; const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
}; };
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1));
let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let dim0 = unsafe { let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
n1.shape() unsafe { u.create_data(generator, ctx) };
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) let s = out_ndarray1_ty.construct_dyn_shape(generator, ctx, &[dk], None);
.unwrap() unsafe { s.create_data(generator, ctx) };
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); let vh = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d1, d1], None);
unsafe { vh.create_data(generator, ctx) };
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let u_c = u.make_contiguous_ndarray(generator, ctx);
let s_c = s.make_contiguous_ndarray(generator, ctx);
let vh_c = vh.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_svd(
ctx,
x1_c.as_base_value().into(),
u_c.as_base_value().into(),
s_c.as_base_value().into(),
vh_c.as_base_value().into(),
None,
);
let u = u.as_base_value().into();
let s = s.as_base_value().into();
let vh = vh.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![u, s, vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
} }
/// Invokes the `np_linalg_inv` linalg function /// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv"; const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_inv(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
} }
/// Invokes the `np_linalg_pinv` linalg function /// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv"; const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_dyn_shape(generator, ctx, &[d0, d1], None);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_pinv(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
} }
/// Invokes the `sp_linalg_lu` linalg function /// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu"; const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { l.create_data(generator, ctx) };
let u = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
unsafe { u.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let l_c = l.make_contiguous_ndarray(generator, ctx);
let u_c = u.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_lu(
ctx,
x1_c.as_base_value().into(),
l_c.as_base_value().into(),
u_c.as_base_value().into(),
None,
);
let l = l.as_base_value().into();
let u = u.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![l, u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} }
/// Invokes the `np_linalg_matrix_power` linalg function /// Invokes the `np_linalg_matrix_power` linalg function
@ -2252,14 +2283,15 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call // Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape( let n2_array = numpy::create_ndarray_const_shape(
generator, generator,
@ -2305,122 +2337,156 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1))
.construct_const_shape(generator, ctx, &[1], None);
unsafe { det.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = det.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_det(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
// Get the determinant out of `out`
let det = unsafe { det.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(det)
} }
/// Invokes the `sp_linalg_schur` linalg function /// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur"; const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
assert_eq!(ndims, 2);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let t = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
t.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { t.create_data(generator, ctx) };
let z = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
z.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { z.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let t_c = t.make_contiguous_ndarray(generator, ctx);
let z_c = z.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_schur(
ctx,
x1_c.as_base_value().into(),
t_c.as_base_value().into(),
z_c.as_base_value().into(),
None,
);
let t = t.as_base_value().into();
let z = z.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![t, z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} }
/// Invokes the `sp_linalg_hessenberg` linalg function /// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>), (x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg"; const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 { let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
assert_eq!(ndims, 2);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")
.map(Into::into)
.unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let h = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
h.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { h.create_data(generator, ctx) };
let q = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
q.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { q.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let h_c = h.make_contiguous_ndarray(generator, ctx);
let q_c = q.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_hessenberg(
ctx,
x1_c.as_base_value().into(),
h_c.as_base_value().into(),
q_c.as_base_value().into(),
None,
);
let h = h.as_base_value().into();
let q = q.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![h, q]);
Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap())
} }

View File

@ -32,7 +32,7 @@ use super::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var, gen_var,
}, },
types::{ListType, ProxyType}, types::{ListType, NDArrayType},
values::{ values::{
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
TypedArrayLikeAccessor, UntypedArrayLikeAccessor, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
@ -42,8 +42,8 @@ use super::{
use crate::{ use crate::{
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::{extract_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::unpack_ndarray_var_tys,
DefinitionId, TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -1112,7 +1112,7 @@ pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
// List structure; type { ty*, size_t } // List structure; type { ty*, size_t }
let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty);
let list = arr_ty.new_value(generator, ctx, name); let list = arr_ty.alloca(generator, ctx, name);
let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap();
list.store_size(ctx, generator, length); list.store_size(ctx, generator, length);
@ -1559,8 +1559,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 { if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let (ndarray_dtype1, ndarray_ndims1) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); let (ndarray_dtype2, ndarray_ndims2) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
let ndarray_ndims1 = extract_ndims(&ctx.unifier, ndarray_ndims1);
let ndarray_ndims2 = extract_ndims(&ctx.unifier, ndarray_ndims2);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1570,12 +1572,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
left_val.into_pointer_value(), left_val.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
Some(ndarray_ndims1),
llvm_usize, llvm_usize,
None, None,
); );
let right_val = NDArrayValue::from_pointer_value( let right_val = NDArrayValue::from_pointer_value(
right_val.into_pointer_value(), right_val.into_pointer_value(),
llvm_ndarray_dtype2, llvm_ndarray_dtype2,
Some(ndarray_ndims2),
llvm_usize, llvm_usize,
None, None,
); );
@ -1625,12 +1629,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
let ndarray_ndims = extract_ndims(&ctx.unifier, ndarray_ndims);
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_val = NDArrayValue::from_pointer_value( let ndarray_val = NDArrayValue::from_pointer_value(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
Some(ndarray_ndims),
llvm_usize, llvm_usize,
None, None,
); );
@ -1822,12 +1828,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_ndims = extract_ndims(&ctx.unifier, ndarray_ndims);
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
let val = NDArrayValue::from_pointer_value( let val = NDArrayValue::from_pointer_value(
val.into_pointer_value(), val.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
Some(ndarray_ndims),
llvm_usize, llvm_usize,
None, None,
); );
@ -1916,8 +1924,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let (ndarray_dtype1, ndarray_ndims1) =
unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
let ndarray_ndims1 = extract_ndims(&ctx.unifier, ndarray_ndims1);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1926,6 +1936,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
lhs.into_pointer_value(), lhs.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
Some(ndarray_ndims1),
llvm_usize, llvm_usize,
None, None,
); );
@ -2549,7 +2560,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, ty: Type,
ndims: Type, ndims_ty: Type,
v: NDArrayValue<'ctx>, v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>, slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
@ -2557,7 +2568,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims_ty) else {
codegen_unreachable!(ctx) codegen_unreachable!(ctx)
}; };
@ -2590,14 +2601,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
_ => 1, _ => 1,
}; };
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
@ -2789,32 +2792,21 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
_ => { _ => {
// Accessing an element from a multi-dimensional `ndarray` // Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
let num_dims = extract_ndims(&ctx.unifier, ndims_ty) - 1;
// Create a new array, remove the top dimension from the dimension-size-list, and copy the // Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over // elements over
let subscripted_ndarray = let ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t, Some(num_dims))
let ndarray = NDArrayValue::from_pointer_value( .construct_uninitialized(
subscripted_ndarray, generator,
llvm_ndarray_data_t, ctx,
llvm_usize, llvm_usize.const_int(num_dims, false),
None, None,
); );
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx let ndarray_num_dims = ctx
.builder .builder
.build_int_z_extend_or_bit_cast( .build_int_z_extend_or_bit_cast(
@ -2842,7 +2834,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = ndarray::call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.shape().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
@ -2852,7 +2844,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
.builder .builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap(); .unwrap();
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); unsafe { ndarray.create_data(generator, ctx) };
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic( call_memcpy_generic(
@ -3539,6 +3531,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let ndarray_ndims = extract_ndims(&ctx.unifier, *ndims);
let llvm_ty = ctx.get_llvm_type(generator, *ty); let llvm_ty = ctx.get_llvm_type(generator, *ty);
let v = if let Some(v) = generator.gen_expr(ctx, value)? { let v = if let Some(v) = generator.gen_expr(ctx, value)? {
@ -3547,7 +3540,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); let v = NDArrayValue::from_pointer_value(
v,
llvm_ty,
Some(ndarray_ndims),
usize,
None,
);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
} }
@ -3598,3 +3597,93 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
_ => unimplemented!(), _ => unimplemented!(),
})) }))
} }
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
pub fn create_fn_and_call<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
fn_name: &str,
ret_type: Option<BasicTypeEnum<'ctx>>,
(params, is_var_args): (&[BasicTypeEnum<'ctx>], bool),
args: &[BasicValueEnum<'ctx>],
call_value_name: Option<&str>,
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
let intrinsic_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
let params = params.iter().copied().map(BasicTypeEnum::into).collect_vec();
let fn_type = if let Some(ret_type) = ret_type {
ret_type.fn_type(params.as_slice(), is_var_args)
} else {
ctx.ctx.void_type().fn_type(params.as_slice(), is_var_args)
};
ctx.module.add_function(fn_name, fn_type, None)
});
if let Some(configure) = configure {
configure(&intrinsic_fn);
}
let args = args.iter().copied().map(BasicValueEnum::into).collect_vec();
ctx.builder
.build_call(intrinsic_fn, args.as_slice(), call_value_name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(Either::left)
.unwrap()
}
/// Emits a `call` to the non-vararg function named `fn_name`, declaring it in the current module
/// if necessary.
///
/// This wraps [`create_fn_and_call`]. Each entry of `params` pairs the declared type of a
/// parameter with the value passed for it, which keeps the expected type and the actual argument
/// of every position next to each other at the call site.
///
/// * `ret_type` - Return type used when declaring the function; `None` declares a `void` function.
/// * `value_name` - Name given to the call instruction.
/// * `configure` - Optional callback forwarded to [`create_fn_and_call`].
///
/// Returns the value produced by the call, if any.
pub fn create_and_call_function<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    fn_name: &str,
    ret_type: Option<BasicTypeEnum<'ctx>>,
    params: &[(BasicTypeEnum<'ctx>, BasicValueEnum<'ctx>)],
    value_name: Option<&str>,
    configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
    // Split the (type, value) pairs into the parallel lists expected by `create_fn_and_call`.
    let (param_tys, arg_values): (Vec<BasicTypeEnum<'ctx>>, Vec<BasicValueEnum<'ctx>>) =
        params.iter().copied().unzip();

    create_fn_and_call(
        ctx,
        fn_name,
        ret_type,
        (param_tys.as_slice(), false),
        arg_values.as_slice(),
        value_name,
        configure,
    )
}
/// Emits a `call` to the non-vararg function named `fn_name`, declaring it in the current module
/// if necessary.
///
/// This wraps [`create_fn_and_call`]. Only the arguments are given; the parameter types used to
/// declare the function are taken from the type of each argument.
///
/// Use this when every argument is known to already have exactly the type of the corresponding
/// parameter of the invoked function.
///
/// * `ret_type` - Return type used when declaring the function; `None` declares a `void` function.
/// * `value_name` - Name given to the call instruction.
/// * `configure` - Optional callback forwarded to [`create_fn_and_call`].
///
/// Returns the value produced by the call, if any.
pub fn infer_and_call_function<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    fn_name: &str,
    ret_type: Option<BasicTypeEnum<'ctx>>,
    args: &[BasicValueEnum<'ctx>],
    value_name: Option<&str>,
    configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
    let inferred_tys: Vec<BasicTypeEnum<'ctx>> =
        args.iter().map(BasicValueEnum::get_type).collect();

    create_fn_and_call(
        ctx,
        fn_name,
        ret_type,
        (inferred_tys.as_slice(), false),
        args,
        value_name,
        configure,
    )
}

View File

@ -13,12 +13,11 @@ use super::{CodeGenContext, CodeGenerator};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
pub use list::*; pub use list::*;
pub use math::*; pub use math::*;
pub use ndarray::*;
pub use slice::*; pub use slice::*;
mod list; mod list;
mod math; mod math;
mod ndarray; pub mod ndarray;
mod slice; mod slice;
#[must_use] #[must_use]
@ -60,6 +59,27 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
irrt_mod irrt_mod
} }
/// Returns the name of a function which has separate variants for 32-bit and 64-bit `size_t`.
///
/// The variant is selected by the bit width of the size type reported by `generator`:
///
/// - When the size type is 32 bits wide, the function name is `{name}`.
/// - When the size type is 64 bits wide, the function name is `{name}64`.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
#[must_use]
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'_, '_>,
    name: &str,
) -> String {
    let suffix = match generator.get_size_type(ctx.ctx).get_bit_width() {
        32 => "",
        64 => "64",
        bit_width => {
            panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
        }
    };

    format!("{name}{suffix}")
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// NOTE: the output value of the end index of this function should be compared ***inclusively***,
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
/// NO numeric slice in python. /// NO numeric slice in python.

View File

@ -0,0 +1,250 @@
use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function},
irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{NDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Emits a call to `__nac3_ndarray_util_assert_shape_no_negative`, or to its `64`-suffixed variant
/// when the size type is 64 bits wide.
///
/// * `ndims` - Passed as the first argument, declared with the size type.
/// * `shape` - Passed as the second argument, declared as a pointer to the size type.
///
/// Any value returned by the call is discarded.
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndims: IntValue<'ctx>,
    shape: PointerValue<'ctx>,
) {
    let size_t = generator.get_size_type(ctx.ctx);
    let size_ptr_t = size_t.ptr_type(AddressSpace::default());

    let fn_name = get_usize_dependent_function_name(
        generator,
        ctx,
        "__nac3_ndarray_util_assert_shape_no_negative",
    );

    create_and_call_function(
        ctx,
        &fn_name,
        Some(size_t.into()),
        &[(size_t.into(), ndims.into()), (size_ptr_t.into(), shape.into())],
        None,
        None,
    );
}
/// Emits a call to `__nac3_ndarray_util_assert_output_shape_same`, or to its `64`-suffixed variant
/// when the size type is 64 bits wide.
///
/// The arguments are forwarded in the order `ndarray_ndims`, `ndarray_shape`, `output_ndims`,
/// `output_shape`, declared respectively as size type, pointer to size type, size type and pointer
/// to size type. Any value returned by the call is discarded.
///
/// NOTE(review): `output_shape` is an [`IntValue`], yet it is passed for a parameter declared as a
/// pointer to the size type. It looks like it should be a [`PointerValue`] like `ndarray_shape`;
/// confirm against the callee and the callers before relying on this function.
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray_ndims: IntValue<'ctx>,
    ndarray_shape: PointerValue<'ctx>,
    output_ndims: IntValue<'ctx>,
    output_shape: IntValue<'ctx>,
) {
    let size_t = generator.get_size_type(ctx.ctx);
    let size_ptr_t = size_t.ptr_type(AddressSpace::default());

    let fn_name = get_usize_dependent_function_name(
        generator,
        ctx,
        "__nac3_ndarray_util_assert_output_shape_same",
    );

    create_and_call_function(
        ctx,
        &fn_name,
        Some(size_t.into()),
        &[
            (size_t.into(), ndarray_ndims.into()),
            (size_ptr_t.into(), ndarray_shape.into()),
            (size_t.into(), output_ndims.into()),
            (size_ptr_t.into(), output_shape.into()),
        ],
        None,
        None,
    );
}
/// Emits a call to `__nac3_ndarray_size`, or to its `64`-suffixed variant when the size type is 64
/// bits wide, passing `ndarray` as the only argument.
///
/// The callee is declared as returning the size type; the result is returned as an [`IntValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
    let size_t = generator.get_size_type(ctx.ctx);
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(size_t.into()),
        &[(ndarray_base_t.into(), ndarray.as_base_value().into())],
        Some("size"),
        None,
    );

    result.map(BasicValueEnum::into_int_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_nbytes`, or to its `64`-suffixed variant when the size type is
/// 64 bits wide, passing `ndarray` as the only argument.
///
/// The callee is declared as returning the size type; the result is returned as an [`IntValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
    let size_t = generator.get_size_type(ctx.ctx);
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(size_t.into()),
        &[(ndarray_base_t.into(), ndarray.as_base_value().into())],
        Some("nbytes"),
        None,
    );

    result.map(BasicValueEnum::into_int_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_len`, or to its `64`-suffixed variant when the size type is 64
/// bits wide, passing `ndarray` as the only argument.
///
/// The callee is declared as returning the size type; the result is returned as an [`IntValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
    let size_t = generator.get_size_type(ctx.ctx);
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(size_t.into()),
        &[(ndarray_base_t.into(), ndarray.as_base_value().into())],
        Some("len"),
        None,
    );

    result.map(BasicValueEnum::into_int_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_is_c_contiguous`, or to its `64`-suffixed variant when the size
/// type is 64 bits wide, passing `ndarray` as the only argument.
///
/// The callee is declared as returning the context's boolean type; the result is returned as an
/// [`IntValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
    let bool_t = ctx.ctx.bool_type();
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name =
        get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(bool_t.into()),
        &[(ndarray_base_t.into(), ndarray.as_base_value().into())],
        Some("is_c_contiguous"),
        None,
    );

    result.map(BasicValueEnum::into_int_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_get_nth_pelement`, or to its `64`-suffixed variant when the
/// size type is 64 bits wide.
///
/// * `ndarray` - Passed as the first argument.
/// * `index` - Passed as the second argument, declared with the size type.
///
/// The callee is declared as returning `i8*`; the result is returned as a [`PointerValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    index: IntValue<'ctx>,
) -> PointerValue<'ctx> {
    let byte_ptr_t = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
    let size_t = generator.get_size_type(ctx.ctx);
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name =
        get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(byte_ptr_t.into()),
        &[(ndarray_base_t.into(), ndarray.as_base_value().into()), (size_t.into(), index.into())],
        Some("pelement"),
        None,
    );

    result.map(BasicValueEnum::into_pointer_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_get_pelement_by_indices`, or to its `64`-suffixed variant when
/// the size type is 64 bits wide.
///
/// * `ndarray` - Passed as the first argument.
/// * `indices` - Passed as the second argument, declared as a pointer to the size type.
///
/// The callee is declared as returning `i8*`; the result is returned as a [`PointerValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    indices: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
    let byte_ptr_t = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
    let size_ptr_t = generator.get_size_type(ctx.ctx).ptr_type(AddressSpace::default());
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name =
        get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");

    let result = create_and_call_function(
        ctx,
        &fn_name,
        Some(byte_ptr_t.into()),
        &[
            (ndarray_base_t.into(), ndarray.as_base_value().into()),
            (size_ptr_t.into(), indices.into()),
        ],
        Some("pelement"),
        None,
    );

    result.map(BasicValueEnum::into_pointer_value).unwrap()
}
/// Emits a call to `__nac3_ndarray_set_strides_by_shape`, or to its `64`-suffixed variant when the
/// size type is 64 bits wide, passing `ndarray` as the only argument.
///
/// The callee is declared as returning `void`.
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) {
    let ndarray_base_t = ndarray.get_type().as_base_type();
    let fn_name =
        get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");

    create_and_call_function(
        ctx,
        &fn_name,
        None,
        &[(ndarray_base_t.into(), ndarray.as_base_value().into())],
        None,
        None,
    );
}
/// Emits a call to `__nac3_ndarray_copy_data`, or to its `64`-suffixed variant when the size type
/// is 64 bits wide.
///
/// `src_ndarray` is passed as the first argument and `dst_ndarray` as the second. The parameter
/// types of the callee are inferred from these two values, and the callee is declared as returning
/// `void`.
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    src_ndarray: NDArrayValue<'ctx>,
    dst_ndarray: NDArrayValue<'ctx>,
) {
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
    let call_args = [src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()];

    infer_and_call_function(ctx, &fn_name, None, &call_args, None, None);
}

View File

@ -0,0 +1,67 @@
use inkwell::{
values::{BasicValueEnum, IntValue},
AddressSpace,
};
use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function},
irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{nditer::NDIterValue, ArrayLikeValue, ArraySliceValue, NDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Emits a call to `__nac3_nditer_initialize`, or to its `64`-suffixed variant when the size type
/// is 64 bits wide.
///
/// * `iter` - Passed as the first argument.
/// * `ndarray` - Passed as the second argument.
/// * `indices` - Its base pointer is passed as the third argument, declared as a pointer to the
///   size type.
///
/// The callee is declared as returning `void`.
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    iter: NDIterValue<'ctx>,
    ndarray: NDArrayValue<'ctx>,
    indices: ArraySliceValue<'ctx>,
) {
    let size_ptr_t = generator.get_size_type(ctx.ctx).ptr_type(AddressSpace::default());
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");

    create_and_call_function(
        ctx,
        &fn_name,
        None,
        &[
            (iter.get_type().as_base_type().into(), iter.as_base_value().into()),
            (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()),
            (size_ptr_t.into(), indices.base_ptr(ctx, generator).into()),
        ],
        None,
        None,
    );
}
/// Emits a call to `__nac3_nditer_has_element`, or to its `64`-suffixed variant when the size type
/// is 64 bits wide, passing `iter` as the only argument.
///
/// The parameter type of the callee is inferred from `iter`, and the callee is declared as
/// returning the context's boolean type. The result is returned as an [`IntValue`].
///
/// # Panics
///
/// Panics if the call does not yield a value.
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    iter: NDIterValue<'ctx>,
) -> IntValue<'ctx> {
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
    let bool_t = ctx.ctx.bool_type();

    let result = infer_and_call_function(
        ctx,
        &fn_name,
        Some(bool_t.into()),
        &[iter.as_base_value().into()],
        None,
        None,
    );

    result.map(BasicValueEnum::into_int_value).unwrap()
}
/// Emits a call to `__nac3_nditer_next`, or to its `64`-suffixed variant when the size type is 64
/// bits wide, passing `iter` as the only argument.
///
/// The parameter type of the callee is inferred from `iter`, and the callee is declared as
/// returning `void`.
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    iter: NDIterValue<'ctx>,
) {
    let fn_name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next");
    let call_args = [iter.as_base_value().into()];

    infer_and_call_function(ctx, &fn_name, None, &call_args, None, None);
}

View File

@ -15,6 +15,11 @@ use crate::codegen::{
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
pub use basic::*;
pub use iter::*;
mod basic;
mod iter;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size. /// calculated total size.
@ -77,7 +82,7 @@ where
/// `NDArray`. /// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>, index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
@ -201,8 +206,8 @@ where
/// `NDArray`. /// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for. /// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>( pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G, generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: &Index, indices: &Index,
) -> IntValue<'ctx> ) -> IntValue<'ctx>

View File

@ -201,6 +201,48 @@ pub fn call_memcpy_generic<'ctx>(
call_memcpy(ctx, dest, src, len, is_volatile); call_memcpy(ctx, dest, src, len, is_volatile);
} }
/// Invokes the `llvm.memcpy` intrinsic to copy `len` elements from `src` to `dest`.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an 8-bit integer, it is cast to `i8*` before invoking `memcpy`.
///
/// `len` is the number of elements (rather than bytes) to copy. The byte count passed to `memcpy`
/// is `len` multiplied by the size of the type `src` points to; the pointee type of `dest` is not
/// taken into account.
pub fn call_memcpy_generic_array<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    dest: PointerValue<'ctx>,
    src: PointerValue<'ctx>,
    len: IntValue<'ctx>,
    is_volatile: IntValue<'ctx>,
) {
    let llvm_i8 = ctx.ctx.i8_type();
    let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
    let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();

    // Returns `ptr` unchanged if it already points to an 8-bit integer, otherwise casts it to
    // `i8*`.
    let as_byte_ptr = |ptr: PointerValue<'ctx>| -> PointerValue<'ctx> {
        if matches!(ptr.get_type().get_element_type(), IntType(t) if t.get_bit_width() == 8) {
            ptr
        } else {
            ctx.builder
                .build_bit_cast(ptr, llvm_p0i8, "")
                .map(BasicValueEnum::into_pointer_value)
                .unwrap()
        }
    };

    // The element size must be read from the original `src` pointer, before it is cast.
    let src_elem_t = src.get_type().get_element_type();

    let dest = as_byte_ptr(dest);
    let src = as_byte_ptr(src);

    // Convert the element count into a byte count.
    let num_elems =
        ctx.builder.build_int_z_extend_or_bit_cast(len, llvm_sizeof_expr_t, "").unwrap();
    let num_bytes =
        ctx.builder.build_int_mul(num_elems, src_elem_t.size_of().unwrap(), "").unwrap();

    call_memcpy(ctx, dest, src, num_bytes, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function) /// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
/// ///
/// Arguments: /// Arguments:
@ -343,3 +385,25 @@ pub fn call_float_powi<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Invokes the [`llvm.ctpop`](https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic) intrinsic.
///
/// The intrinsic is declared for the type of `src`. `name` is the name given to the call
/// instruction and defaults to the empty string.
///
/// # Panics
///
/// Panics if the intrinsic or its declaration for the type of `src` cannot be obtained, or if the
/// call does not yield a value.
pub fn call_int_ctpop<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    src: IntValue<'ctx>,
    name: Option<&str>,
) -> IntValue<'ctx> {
    const FN_NAME: &str = "llvm.ctpop";

    let src_ty = src.get_type();

    let ctpop_fn = Intrinsic::find(FN_NAME)
        .and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[src_ty.into()]))
        .unwrap();

    let call_site =
        ctx.builder.build_call(ctpop_fn, &[src.into()], name.unwrap_or_default()).unwrap();

    call_site.try_as_basic_value().map_left(BasicValueEnum::into_int_value).unwrap_left()
}

View File

@ -30,7 +30,11 @@ use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{ use crate::{
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{
helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys,
TopLevelContext, TopLevelDef,
},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -510,12 +514,13 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let (dtype, ndims) = unpack_ndarray_var_tys(unifier, ty);
let ndims = extract_ndims(unifier, ndims);
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
NDArrayType::new(generator, ctx, element_type).as_base_type().into() NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -1119,3 +1124,106 @@ fn gen_in_range_check<'ctx>(
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef { fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into() format!("__{}_va_count", &arg_name).into()
} }
/// Returns the alignment of `ty` as reported by the `get_alignment` method of its concrete type.
///
/// This helper exists because `get_alignment` is not implemented as part of [`BasicType`], so the
/// call has to be dispatched on the concrete variant of [`BasicTypeEnum`].
pub fn get_type_alignment<'ctx>(ty: impl Into<BasicTypeEnum<'ctx>>) -> IntValue<'ctx> {
    let ty: BasicTypeEnum<'ctx> = ty.into();

    match ty {
        BasicTypeEnum::ArrayType(array_ty) => array_ty.get_alignment(),
        BasicTypeEnum::FloatType(float_ty) => float_ty.get_alignment(),
        BasicTypeEnum::IntType(int_ty) => int_ty.get_alignment(),
        BasicTypeEnum::PointerType(ptr_ty) => ptr_ty.get_alignment(),
        BasicTypeEnum::StructType(struct_ty) => struct_ty.get_alignment(),
        BasicTypeEnum::VectorType(vector_ty) => vector_ty.get_alignment(),
    }
}
/// Inserts an `alloca` instruction for a buffer of `size` bytes that uses `align_ty` as its
/// allocated element type.
///
/// `size` is rounded up to a multiple of the alignment of `align_ty`, and the buffer is allocated
/// as that many bytes divided by the alignment, counted in elements of `align_ty`.
///
/// The returned [`PointerValue`] has type `i8*`. It is intended to be at least `size` bytes large
/// and aligned like `align_ty`.
///
/// When the optimization level is [`OptimizationLevel::None`], an assertion is also emitted which
/// checks that the alignment of `align_ty` is a power of two.
///
/// * `name` - Name given to the returned pointer; defaults to the empty string.
pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    align_ty: impl Into<BasicTypeEnum<'ctx>>,
    size: IntValue<'ctx>,
    name: Option<&str>,
) -> PointerValue<'ctx> {
    /// Rounds `val` up to the nearest multiple of `power_of_two`.
    ///
    /// Computed as `(val + (power_of_two - 1)) & ~(power_of_two - 1)`.
    fn align_up<'ctx>(
        ctx: &CodeGenContext<'ctx, '_>,
        val: IntValue<'ctx>,
        power_of_two: IntValue<'ctx>,
    ) -> IntValue<'ctx> {
        debug_assert_eq!(
            val.get_type().get_bit_width(),
            power_of_two.get_type().get_bit_width(),
            "`val` ({}) and `power_of_two` ({}) must be the same type",
            val.get_type(),
            power_of_two.get_type(),
        );

        let one = val.get_type().const_int(1, false);
        let low_bits = ctx.builder.build_int_sub(power_of_two, one, "").unwrap();
        let bumped = ctx.builder.build_int_add(val, low_bits, "").unwrap();
        let high_bits = ctx.builder.build_not(low_bits, "").unwrap();

        ctx.builder.build_and(bumped, high_bits, "").unwrap()
    }

    let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let align_ty: BasicTypeEnum<'ctx> = align_ty.into();

    // All arithmetic below is performed in `size_t`.
    let size = ctx.builder.build_int_z_extend_or_bit_cast(size, llvm_usize, "").unwrap();

    debug_assert_eq!(
        size.get_type().get_bit_width(),
        llvm_usize.get_bit_width(),
        "Expected size_t ({}) for parameter `size` of `aligned_alloca`, got {}",
        llvm_usize,
        size.get_type(),
    );

    let alignment = ctx
        .builder
        .build_int_z_extend_or_bit_cast(get_type_alignment(align_ty), llvm_usize, "")
        .unwrap();

    if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
        // A power of two has exactly one bit set.
        let set_bits = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
        let is_power_of_two = ctx
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                set_bits,
                set_bits.get_type().const_int(1, false),
                "",
            )
            .unwrap();

        ctx.make_assert(
            generator,
            is_power_of_two,
            "0:AssertionError",
            "Expected power-of-two alignment for aligned_alloca, got {0}",
            [Some(alignment), None, None],
            ctx.current_loc,
        );
    }

    let padded_size = align_up(ctx, size, alignment);
    let num_slices = ctx.builder.build_int_unsigned_div(padded_size, alignment, "").unwrap();

    // Allocate in units of `align_ty`, with one element per `alignment` bytes of the padded size.
    // NOTE(review): this covers `size` bytes only if `align_ty` is at least as large as its
    // alignment - confirm for the types passed by callers.
    let slices = ctx.builder.build_array_alloca(align_ty, num_slices, "").unwrap();

    ctx.builder
        .build_bit_cast(slices, llvm_pi8, name.unwrap_or_default())
        .map(BasicValueEnum::into_pointer_value)
        .unwrap()
}

View File

@ -3,14 +3,18 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools;
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
use super::{ use super::{
expr::gen_binop_expr_with_values, expr::gen_binop_expr_with_values,
irrt::{ irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast, calculate_len_for_slice_range,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, ndarray::{
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
},
}, },
llvm_intrinsics::{self, call_memcpy_generic}, llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable, macros::codegen_unreachable,
@ -27,7 +31,7 @@ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, PrimDef}, helper::{arraylike_flatten_element_type, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::unpack_ndarray_var_tys,
DefinitionId, DefinitionId,
}, },
typecheck::{ typecheck::{
@ -36,28 +40,6 @@ use crate::{
}, },
}; };
/// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx
.get_llvm_type(generator, ndarray_ty)
.into_pointer_type()
.get_element_type()
.into_struct_type();
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
}
/// Creates an `NDArray` instance from a dynamic shape. /// Creates an `NDArray` instance from a dynamic shape.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -83,6 +65,7 @@ where
) -> Result<IntValue<'ctx>, String>, ) -> Result<IntValue<'ctx>, String>,
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
// Assert that all dimensions are non-negative // Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -122,13 +105,10 @@ where
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
let num_dims = shape_len_fn(generator, ctx, shape)?; let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); .construct_uninitialized(generator, ctx, num_dims, None);
// Copy the dimension sizes from shape to ndarray.dims // Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -153,7 +133,7 @@ where
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); unsafe { ndarray.create_data(generator, ctx) };
Ok(ndarray) Ok(ndarray)
} }
@ -189,54 +169,15 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
// TODO: Disallow dim_sz > u32_MAX // TODO: Disallow dim_sz > u32_MAX
} }
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
let num_dims = llvm_usize.const_int(shape.len() as u64, false); let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64))
ndarray.store_ndims(ctx, generator, num_dims); .construct_dyn_shape(generator, ctx, shape, None);
unsafe { ndarray.create_data(generator, ctx) };
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.shape().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, true),
None,
)
};
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray) Ok(ndarray)
} }
/// Populates the `data` field of `ndarray`, sizing it from the array's shape.
///
/// The number of elements is obtained by passing the shape of `ndarray` to
/// [`call_ndarray_calc_size`]; the shape must therefore already be set on `ndarray` before this
/// function is called. The same `ndarray` is returned to allow chaining.
///
/// # Panics
///
/// Panics if the LLVM type corresponding to `elem_ty` is not sized.
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
    ndarray: NDArrayValue<'ctx>,
) -> NDArrayValue<'ctx> {
    let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
    assert!(llvm_elem_ty.is_sized());

    // Total element count, computed over the array's shape.
    let shape = ndarray.shape().as_slice_value(ctx, generator);
    let num_elems = call_ndarray_calc_size(generator, ctx, &shape, (None, None));

    ndarray.create_data(ctx, llvm_elem_ty, num_elems);
    ndarray
}
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -338,20 +279,24 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
// Get the length/size of the tuple, which also happens to be the value of `ndims`. // Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields(); let ndims = shape_tuple.get_type().count_fields();
let mut shape = Vec::with_capacity(ndims as usize); let shape = (0..ndims)
for dim_i in 0..ndims { .map(|dim_i| {
let dim = ctx ctx.builder
.builder
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.map(BasicValueEnum::into_int_value)
.map(|v| {
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
})
.unwrap() .unwrap()
.into_int_value(); })
.collect_vec();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
} }
BasicValueEnum::IntValue(shape_int) => { BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let shape_int =
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
@ -505,6 +450,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -517,6 +463,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -532,6 +479,7 @@ where
let lhs = NDArrayValue::from_pointer_value( let lhs = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -548,6 +496,7 @@ where
let rhs = NDArrayValue::from_pointer_value( let rhs = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -706,7 +655,8 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
{ {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
.load_ndims(ctx)
} }
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
@ -800,7 +750,8 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
_ => { _ => {
let lst_len = src_lst.load_size(ctx, None); let lst_len = src_lst.load_size(ctx, None);
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); let sizeof_elem =
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap();
let cpy_len = ctx let cpy_len = ctx
.builder .builder
@ -856,7 +807,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_representable(object, llvm_usize).is_ok() { if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
let ndarray = gen_if_else_expr_callback( let ndarray = gen_if_else_expr_callback(
generator, generator,
@ -932,6 +883,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
return Ok(NDArrayValue::from_pointer_value( return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
)); ));
@ -1207,7 +1159,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
.build_int_mul( .build_int_mul(
src_data_offset, src_data_offset,
ctx.builder ctx.builder
.build_int_cast(sizeof_elem, src_data_offset.get_type(), "") .build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
.unwrap(), .unwrap(),
"", "",
) )
@ -1220,7 +1172,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
.build_int_mul( .build_int_mul(
dst_data_offset, dst_data_offset,
ctx.builder ctx.builder
.build_int_cast(sizeof_elem, dst_data_offset.get_type(), "") .build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
.unwrap(), .unwrap(),
"", "",
) )
@ -1269,6 +1221,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = if slices.is_empty() { let ndarray = if slices.is_empty() {
create_ndarray_dyn_shape( create_ndarray_dyn_shape(
@ -1282,8 +1235,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
}, },
)? )?
} else { } else {
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); .construct_uninitialized(generator, ctx, this.load_ndims(ctx), None);
let ndims = this.load_ndims(ctx); let ndims = this.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndims); ndarray.create_shape(ctx, llvm_usize, ndims);
@ -1346,7 +1299,9 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
) )
.unwrap(); .unwrap();
ndarray_init_data(generator, ctx, elem_ty, ndarray) unsafe { ndarray.create_data(generator, ctx) };
ndarray
}; };
ndarray_sliced_copyto_impl( ndarray_sliced_copyto_impl(
@ -1465,6 +1420,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1473,6 +1429,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1499,6 +1456,7 @@ where
let ndarray = NDArrayValue::from_pointer_value( let ndarray = NDArrayValue::from_pointer_value(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -2061,6 +2019,7 @@ pub fn gen_ndarray_copy<'ctx>(
NDArrayValue::from_pointer_value( NDArrayValue::from_pointer_value(
this_arg.into_pointer_value(), this_arg.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
), ),
@ -2098,7 +2057,7 @@ pub fn gen_ndarray_fill<'ctx>(
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
context, context,
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, _| { |generator, ctx, _| {
let value = if value_arg.is_pointer_value() { let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
@ -2140,7 +2099,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
@ -2260,7 +2219,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2548,8 +2507,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));

View File

@ -471,6 +471,6 @@ fn test_classes_ndarray_type_new() {
let llvm_i32 = ctx.i32_type(); let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx); let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None);
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
} }

View File

@ -111,6 +111,31 @@ impl<'ctx> ListType<'ctx> {
.map(PointerType::get_element_type) .map(PointerType::get_element_type)
.unwrap() .unwrap()
} }
/// Allocates a new [`ListValue`] as if by calling `alloca` on the base type.
///
/// The storage comes from [`ProxyType::raw_alloca`]; the resulting pointer is then wrapped in
/// the proxy value type without initializing any of its fields.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let list_ptr = self.raw_alloca(generator, ctx, name);

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(list_ptr, self.llvm_usize, name)
}
/// Wraps an existing base `value` as a [`ListValue`], using this type's size type.
#[must_use]
pub fn map_value(
    &self,
    value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let llvm_usize = self.llvm_usize;

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(value, llvm_usize, name)
}
} }
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
@ -137,25 +162,22 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx)) Self::is_representable(llvm_ty, generator.get_size_type(ctx))
} }
fn new_value<G: CodeGenerator + ?Sized>( fn raw_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self::Value { ) -> <Self::Value as ProxyValue<'ctx>>::Base {
self.map_value(
generator generator
.gen_var_alloc( .gen_var_alloc(
ctx, ctx,
self.as_base_type().get_element_type().into_struct_type().into(), self.as_base_type().get_element_type().into_struct_type().into(),
name, name,
) )
.unwrap(), .unwrap()
name,
)
} }
fn new_array_value<G: CodeGenerator + ?Sized>( fn array_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -172,14 +194,6 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
.unwrap() .unwrap()
} }
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
Self::Value::from_pointer_value(value, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base { fn as_base_type(&self) -> Self::Base {
self.ty self.ty
} }

View File

@ -1,3 +1,21 @@
//! This module contains abstractions over all intrinsic composite types of NAC3.
//!
//! # `raw_alloca` vs `alloca` vs `construct`
//!
//! There are three ways of creating a new object instance using the abstractions provided by this
//! module.
//!
//! - `raw_alloca`: Allocates the object on the stack, returning an instance of
//! [`impl BasicValue`][inkwell::values::BasicValue]. This is similar to a `malloc` expression in
//! C++ but the object is allocated on the stack.
//! - `alloca`: Similar to `raw_alloca`, but also wraps the allocated object with
//! [`<Self as ProxyType<'ctx>>::Value`][ProxyValue], and returns the wrapped object. The returned
//! object will not initialize any value or fields. This is similar to a type-safe `malloc`
//! expression in C++ but the object is allocated on the stack.
//! - `construct`: Similar to `alloca`, but performs some initialization on the value or fields of
//! the returned object. This is similar to a `new` expression in C++ but the object is allocated
//! on the stack.
use inkwell::{context::Context, types::BasicType, values::IntValue}; use inkwell::{context::Context, types::BasicType, values::IntValue};
use super::{ use super::{
@ -35,16 +53,17 @@ pub trait ProxyType<'ctx>: Into<Self::Base> {
llvm_ty: Self::Base, llvm_ty: Self::Base,
) -> Result<(), String>; ) -> Result<(), String>;
/// Creates a new value of this type. /// Creates a new value of this type, returning the LLVM instance of this value.
fn new_value<G: CodeGenerator + ?Sized>( fn raw_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self::Value; ) -> <Self::Value as ProxyValue<'ctx>>::Base;
/// Creates a new array value of this type. /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the
fn new_array_value<G: CodeGenerator + ?Sized>( /// resulting array.
fn array_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -52,13 +71,6 @@ pub trait ProxyType<'ctx>: Into<Self::Base> {
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx>; ) -> ArraySliceValue<'ctx>;
/// Converts an existing value into a [`ProxyValue`] of this type.
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value;
/// Returns the [base type][Self::Base] of this proxy. /// Returns the [base type][Self::Base] of this proxy.
fn as_base_type(&self) -> Self::Base; fn as_base_type(&self) -> Self::Base;
} }

View File

@ -1,258 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::{
structure::{StructField, StructFields},
ProxyType,
};
use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `ndarray` type in LLVM.
///
/// Bundles the LLVM pointer type of the `ndarray` struct with the element type and the size type
/// that are forwarded to every [`NDArrayValue`] created through this proxy.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
// LLVM pointer-to-struct type of the `ndarray`; this is what `as_base_type` returns.
ty: PointerType<'ctx>,
// LLVM type of the array elements; this is what `element_type` returns.
dtype: BasicTypeEnum<'ctx>,
// Integer type used as the size type; `new` obtains it from the code generator.
llvm_usize: IntType<'ctx>,
}
/// Field accessors for the LLVM struct backing an `ndarray`.
///
/// The struct type created by `NDArrayType::llvm_type` is built from the field types yielded by
/// this definition, so the fields here determine the struct's layout.
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> {
/// Number of dimensions in the array.
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
/// Pointer to an array containing the shape of the `ndarray`.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to an array containing the array data, typed as `i8*`.
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDArrayType<'ctx> {
    /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
    ///
    /// `llvm_ty` must point to a struct with exactly three fields:
    ///
    /// 0. an integer with the same bit width as `llvm_usize`,
    /// 1. a pointer to an integer with the same bit width as `llvm_usize`,
    /// 2. a pointer to an 8-bit integer.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first mismatch that was found.
    pub fn is_representable(
        llvm_ty: PointerType<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Result<(), String> {
        let llvm_ndarray_ty = llvm_ty.get_element_type();
        let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
            return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
        };
        if llvm_ndarray_ty.count_fields() != 3 {
            return Err(format!(
                "Expected 3 fields in `NDArray`, got {}",
                llvm_ndarray_ty.count_fields()
            ));
        }

        // Field 0: integer of the size type's width.
        let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
        let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
            return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
        };
        if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
            return Err(format!(
                "Expected {}-bit int type for `ndarray.0`, got {}-bit int",
                llvm_usize.get_bit_width(),
                ndarray_ndims_ty.get_bit_width()
            ));
        }

        // Field 1: pointer to integers of the size type's width.
        let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
        let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
            return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
        };
        let ndarray_dims = ndarray_pdims.get_element_type();
        let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
            return Err(format!(
                "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
            ));
        };
        if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
            return Err(format!(
                "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
                llvm_usize.get_bit_width(),
                ndarray_dims.get_bit_width()
            ));
        }

        // Field 2: pointer to 8-bit integers.
        let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
        let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
            return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
        };
        let ndarray_data = ndarray_pdata.get_element_type();
        let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
            return Err(format!(
                "Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
            ));
        };
        if ndarray_data.get_bit_width() != 8 {
            // This check concerns field 2; the message previously named `ndarray.1`.
            return Err(format!(
                "Expected pointer-to-8-bit int type for `ndarray.2`, got pointer-to-{}-bit int",
                ndarray_data.get_bit_width()
            ));
        }

        Ok(())
    }

    /// Returns the field accessors of the `ndarray` struct for the given size type.
    // TODO: Move this into e.g. StructProxyType
    #[must_use]
    fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
        NDArrayStructFields::new(ctx, llvm_usize)
    }

    /// Returns the field accessors of the `ndarray` struct for the given size type.
    ///
    /// Note that the `llvm_usize` argument is used, not the size type stored in `self`.
    // TODO: Move this into e.g. StructProxyType
    #[must_use]
    pub fn get_fields(
        &self,
        ctx: &'ctx Context,
        llvm_usize: IntType<'ctx>,
    ) -> NDArrayStructFields<'ctx> {
        Self::fields(ctx, llvm_usize)
    }

    /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
    #[must_use]
    fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
        // struct NDArray { ndims: size_t, shape: size_t*, data: i8* }
        //
        // * ndims : Number of dimensions in the array
        // * shape : Pointer to an array containing the shape of the NDArray
        // * data  : Pointer to an array containing the array data
        let field_tys =
            Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();

        ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
    }

    /// Creates an instance of [`NDArrayType`] with elements of type `dtype`.
    ///
    /// The size type is taken from `generator`.
    #[must_use]
    pub fn new<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        dtype: BasicTypeEnum<'ctx>,
    ) -> Self {
        let llvm_usize = generator.get_size_type(ctx);
        let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);

        NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
    }

    /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
    ///
    /// In debug builds, asserts that `ptr_ty` passes [`NDArrayType::is_representable`].
    #[must_use]
    pub fn from_type(
        ptr_ty: PointerType<'ctx>,
        dtype: BasicTypeEnum<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Self {
        debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());

        NDArrayType { ty: ptr_ty, dtype, llvm_usize }
    }

    /// Returns the size type of this `ndarray` type, read from the type of the first struct field
    /// (`ndims`).
    #[must_use]
    pub fn size_type(&self) -> IntType<'ctx> {
        self.as_base_type()
            .get_element_type()
            .into_struct_type()
            .get_field_type_at_index(0)
            .map(BasicTypeEnum::into_int_type)
            .unwrap()
    }

    /// Returns the element type of this `ndarray` type.
    #[must_use]
    pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
        self.dtype
    }
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
    type Base = PointerType<'ctx>;
    type Value = NDArrayValue<'ctx>;

    /// Checks that `llvm_ty` is a pointer type which is representable as an `ndarray`.
    fn is_type<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: impl BasicType<'ctx>,
    ) -> Result<(), String> {
        let BasicTypeEnum::PointerType(ptr_ty) = llvm_ty.as_basic_type_enum() else {
            return Err(format!("Expected pointer type, got {llvm_ty:?}"));
        };

        <Self as ProxyType<'ctx>>::is_representable(generator, ctx, ptr_ty)
    }

    /// Delegates to the inherent `is_representable`, using the size type of `generator`.
    fn is_representable<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: Self::Base,
    ) -> Result<(), String> {
        let llvm_usize = generator.get_size_type(ctx);

        Self::is_representable(llvm_ty, llvm_usize)
    }

    /// Allocates the backing struct through the generator and wraps the pointer as a value.
    fn new_value<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        name: Option<&'ctx str>,
    ) -> Self::Value {
        let llvm_struct_ty = self.as_base_type().get_element_type().into_struct_type();
        let ptr = generator.gen_var_alloc(ctx, llvm_struct_ty.into(), name).unwrap();

        self.map_value(ptr, name)
    }

    /// Allocates `size` instances of the backing struct through the generator.
    fn new_array_value<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        size: IntValue<'ctx>,
        name: Option<&'ctx str>,
    ) -> ArraySliceValue<'ctx> {
        let llvm_struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_array_var_alloc(ctx, llvm_struct_ty.into(), size, name).unwrap()
    }

    /// Wraps `value` as an [`NDArrayValue`]; in debug builds, asserts that its type matches the
    /// base type of this proxy.
    fn map_value(
        &self,
        value: <Self::Value as ProxyValue<'ctx>>::Base,
        name: Option<&'ctx str>,
    ) -> Self::Value {
        debug_assert_eq!(value.get_type(), self.as_base_type());

        let (dtype, llvm_usize) = (self.dtype, self.llvm_usize);
        NDArrayValue::from_pointer_value(value, dtype, llvm_usize, name)
    }

    fn as_base_type(&self) -> Self::Base {
        self.ty
    }
}
/// Unwraps the proxy into its underlying LLVM pointer type.
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
    fn from(value: NDArrayType<'ctx>) -> Self {
        // Same value as `as_base_type()`, which returns the `ty` field.
        value.ty
    }
}

View File

@ -0,0 +1,253 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::ProxyType;
use crate::{
codegen::{
types::structure::{FieldIndexCounter, StructField, StructFields},
values::{ArraySliceValue, ContiguousNDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
},
toplevel::numpy::unpack_ndarray_var_tys,
typecheck::typedef::Type,
};
/// Proxy type for a `ContiguousNDArray` type in LLVM.
///
/// Bundles the LLVM pointer type of the struct with the item type and the size type that are
/// forwarded to every [`ContiguousNDArrayValue`] created through this proxy.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ContiguousNDArrayType<'ctx> {
// LLVM pointer-to-struct type; this is what `as_base_type` returns.
ty: PointerType<'ctx>,
// LLVM type of the array items; the `data` field of the struct points to this type.
item: BasicTypeEnum<'ctx>,
// Integer type used as the size type; `new` obtains it from the code generator.
llvm_usize: IntType<'ctx>,
}
/// Field accessors for the LLVM struct backing a `ContiguousNDArray`.
///
/// The `value_type` attributes below describe the untyped layout, in which `data` is an `i8*`.
/// `new_typed` builds the same set of fields with `data` pointing to a concrete item type.
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct ContiguousNDArrayFields<'ctx> {
/// Field of the size type named `ndims`.
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
/// Pointer to values of the size type, named `shape`.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to the array data.
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> ContiguousNDArrayFields<'ctx> {
    /// Creates the field accessors with the `data` field typed as a pointer to `item`.
    ///
    /// Fields are registered with the index counter in declaration order: `ndims`, `shape`,
    /// `data`.
    #[must_use]
    pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
        let mut index = FieldIndexCounter::default();

        // The creation order below determines the index assigned to each field.
        let ndims = StructField::create(&mut index, "ndims", llvm_usize);
        let shape =
            StructField::create(&mut index, "shape", llvm_usize.ptr_type(AddressSpace::default()));
        let data = StructField::create(&mut index, "data", item.ptr_type(AddressSpace::default()));

        ContiguousNDArrayFields { ndims, shape, data }
    }
}
impl<'ctx> ContiguousNDArrayType<'ctx> {
    /// Checks whether `llvm_ty` represents a `ContiguousNDArray` type, returning [Err] if it does
    /// not.
    ///
    /// `llvm_ty` must point to a struct with the same number of fields as
    /// [`ContiguousNDArrayFields`]. Every field except `data` must have exactly the expected
    /// type; `data` is only required to be a pointer, as its pointee depends on the item type.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first mismatch that was found.
    pub fn is_representable(
        llvm_ty: PointerType<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Result<(), String> {
        let ctx = llvm_ty.get_context();

        let fields = ContiguousNDArrayFields::new(ctx, llvm_usize);
        let llvm_expected_ty = fields.to_vec();

        let llvm_ndarray_ty = llvm_ty.get_element_type();
        let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
            return Err(format!(
                "Expected struct type for `ContiguousNDArray` type, got {llvm_ndarray_ty}"
            ));
        };
        if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
            return Err(format!(
                "Expected {} fields in `ContiguousNDArray`, got {}",
                llvm_expected_ty.len(),
                llvm_ndarray_ty.count_fields()
            ));
        }

        llvm_expected_ty
            .iter()
            .enumerate()
            .map(|(i, expected_ty)| {
                // The field count was validated above, so every index fits in a `u32` and refers
                // to an existing field.
                (
                    expected_ty.0,
                    expected_ty.1,
                    llvm_ndarray_ty.get_field_type_at_index(u32::try_from(i).unwrap()).unwrap(),
                )
            })
            .try_for_each(|(field_name, expected_ty, actual_ty)| {
                if field_name == fields.data.name() {
                    // The pointee of `data` is the item type, so only require a pointer here.
                    if actual_ty.is_pointer_type() {
                        Ok(())
                    } else {
                        Err(format!(
                            "Expected T* for `ContiguousNDArray.{field_name}`, got {actual_ty}"
                        ))
                    }
                } else if expected_ty == actual_ty {
                    Ok(())
                } else {
                    Err(format!(
                        "Expected {expected_ty} for `ContiguousNDArray.{field_name}`, got {actual_ty}"
                    ))
                }
            })?;

        Ok(())
    }

    /// Returns an instance of [`StructFields`] containing all field accessors for this type.
    #[must_use]
    fn fields(
        item: BasicTypeEnum<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> ContiguousNDArrayFields<'ctx> {
        ContiguousNDArrayFields::new_typed(item, llvm_usize)
    }

    /// Returns the field accessors for this type, using its own item type and size type.
    // TODO: Move this into e.g. StructProxyType
    #[must_use]
    pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> {
        Self::fields(self.item, self.llvm_usize)
    }

    /// Creates an LLVM type corresponding to the expected structure of a `ContiguousNDArray`
    /// whose `data` field points to `item`.
    #[must_use]
    fn llvm_type(
        ctx: &'ctx Context,
        item: BasicTypeEnum<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> PointerType<'ctx> {
        let field_tys =
            Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec();

        ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
    }

    /// Creates an instance of [`ContiguousNDArrayType`] with items of type `item`.
    ///
    /// The size type is taken from `generator`.
    #[must_use]
    pub fn new<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        item: BasicTypeEnum<'ctx>,
    ) -> Self {
        let llvm_usize = generator.get_size_type(ctx);
        let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);

        Self { ty: llvm_cndarray, item, llvm_usize }
    }

    /// Creates a [`ContiguousNDArrayType`] from a [unifier type][Type].
    ///
    /// The item type is the LLVM type of the `dtype` unpacked from `ty`.
    #[must_use]
    pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        ty: Type,
    ) -> Self {
        let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
        let llvm_dtype = ctx.get_llvm_type(generator, dtype);

        // Construction is identical to `new` once the item type is known.
        Self::new(generator, ctx.ctx, llvm_dtype)
    }

    /// Creates a [`ContiguousNDArrayType`] from a [`PointerType`] representing a
    /// `ContiguousNDArray`.
    ///
    /// In debug builds, asserts that `ptr_ty` passes
    /// [`ContiguousNDArrayType::is_representable`].
    #[must_use]
    pub fn from_type(
        ptr_ty: PointerType<'ctx>,
        item: BasicTypeEnum<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Self {
        debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());

        Self { ty: ptr_ty, item, llvm_usize }
    }

    /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base
    /// type.
    #[must_use]
    pub fn alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        name: Option<&'ctx str>,
    ) -> <Self as ProxyType<'ctx>>::Value {
        <Self as ProxyType<'ctx>>::Value::from_pointer_value(
            self.raw_alloca(generator, ctx, name),
            self.item,
            self.llvm_usize,
            name,
        )
    }
}
impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
    type Base = PointerType<'ctx>;
    type Value = ContiguousNDArrayValue<'ctx>;

    /// Checks that `llvm_ty` is a pointer type which passes `is_representable`.
    fn is_type<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: impl BasicType<'ctx>,
    ) -> Result<(), String> {
        let BasicTypeEnum::PointerType(ptr_ty) = llvm_ty.as_basic_type_enum() else {
            return Err(format!("Expected pointer type, got {llvm_ty:?}"));
        };

        <Self as ProxyType<'ctx>>::is_representable(generator, ctx, ptr_ty)
    }

    /// Delegates to the inherent `is_representable`, using the size type of `generator`.
    fn is_representable<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: Self::Base,
    ) -> Result<(), String> {
        let llvm_usize = generator.get_size_type(ctx);

        Self::is_representable(llvm_ty, llvm_usize)
    }

    /// Allocates the backing struct through the generator and returns the raw pointer.
    fn raw_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        name: Option<&'ctx str>,
    ) -> <Self::Value as ProxyValue<'ctx>>::Base {
        let llvm_struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_var_alloc(ctx, llvm_struct_ty.into(), name).unwrap()
    }

    /// Allocates `size` instances of the backing struct through the generator.
    fn array_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        size: IntValue<'ctx>,
        name: Option<&'ctx str>,
    ) -> ArraySliceValue<'ctx> {
        let llvm_struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_array_var_alloc(ctx, llvm_struct_ty.into(), size, name).unwrap()
    }

    fn as_base_type(&self) -> Self::Base {
        self.ty
    }
}
/// Unwraps the proxy into its underlying LLVM pointer type.
impl<'ctx> From<ContiguousNDArrayType<'ctx>> for PointerType<'ctx> {
    fn from(value: ContiguousNDArrayType<'ctx>) -> Self {
        // Same value as `as_base_type()`, which returns the `ty` field.
        value.ty
    }
}

View File

@ -0,0 +1,395 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::{
structure::{StructField, StructFields},
ProxyType,
};
use crate::{
codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
{CodeGenContext, CodeGenerator},
},
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type,
};
pub use contiguous::*;
mod contiguous;
pub mod nditer;
/// Proxy type for a `ndarray` type in LLVM.
///
/// Wraps the LLVM pointer-to-struct type together with the element type, the (optional)
/// number of dimensions, and the `size_t` integer type used by the structure's fields.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
/// Pointer type to the LLVM structure representing the `ndarray`.
ty: PointerType<'ctx>,
/// LLVM type of the elements stored in the `ndarray`.
dtype: BasicTypeEnum<'ctx>,
/// Number of dimensions of the `ndarray`, or `None` if it is not known to this type.
ndims: Option<u64>,
/// Integer type used for the `size_t`-typed fields of the structure.
llvm_usize: IntType<'ctx>,
}
/// Field accessors for the LLVM structure backing an `ndarray`.
///
/// NOTE(review): the `StructFields` derive presumably emits the fields in declaration order, which
/// would make this order the LLVM struct layout -- confirm before reordering anything here.
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> {
/// The size of each `ndarray` element in bytes.
#[value_type(usize)]
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
/// Number of dimensions in the array.
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
/// Pointer to an array containing the shape of the `ndarray`.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to an array indicating the number of bytes between each element at a dimension.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub strides: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to an array containing the array data, typed as `i8*`.
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
///
/// `llvm_ty` must point to a struct type whose number of fields and per-field types match the
/// fields produced by [`NDArrayType::fields`] for the given `llvm_usize`.
///
/// # Errors
///
/// Returns a message describing the first mismatch found: a non-struct pointee, a wrong field
/// count, or the name of the first field whose type differs from the expected one.
pub fn is_representable(
    llvm_ty: PointerType<'ctx>,
    llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
    let ctx = llvm_ty.get_context();

    let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();

    let llvm_ndarray_ty = llvm_ty.get_element_type();
    let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
        return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
    };
    if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
        return Err(format!(
            "Expected {} fields in `NDArray`, got {}",
            llvm_expected_ty.len(),
            llvm_ndarray_ty.count_fields()
        ));
    }

    // Pair every expected field with a `u32` index so that no truncating cast is needed.
    for (i, expected) in (0u32..).zip(&llvm_expected_ty) {
        let (field_name, expected_ty) = (expected.0, expected.1);
        // The field count was validated above, so every index in range has a type.
        let actual_ty = llvm_ndarray_ty.get_field_type_at_index(i).unwrap();

        if expected_ty != actual_ty {
            // Report the field that actually mismatched rather than a hard-coded one.
            return Err(format!(
                "Expected {expected_ty} for `ndarray.{field_name}`, got {actual_ty}"
            ));
        }
    }

    Ok(())
}
/// Builds the [`StructFields`] instance holding every field accessor of the `NDArray`
/// structure, using `llvm_usize` as the `size_t` type.
#[must_use]
fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
    NDArrayStructFields::new(ctx, llvm_usize)
}
/// Instance-level counterpart of [`NDArrayType::fields`] that supplies this type's own
/// `llvm_usize`.
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDArrayStructFields<'ctx> {
    let llvm_usize = self.llvm_usize;

    Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
///
/// The result is a pointer (default address space) to a non-packed struct whose members are the
/// types yielded by [`NDArrayType::fields`], in iteration order.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
    // struct NDArray { itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t*, data: i8* }
    //
    // * itemsize: The size of each NDArray elements in bytes
    // * ndims   : Number of dimensions in the array
    // * shape   : Pointer to an array containing the shape of the NDArray
    // * strides : Pointer to an array indicating the number of bytes between each element at a dimension
    // * data    : Pointer to an array containing the array data
    //
    // NOTE(review): listed in the declaration order of `NDArrayStructFields`; confirm that the
    // `StructFields` derive iterates fields in that order.
    let field_tys =
        Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();

    ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
}
/// Creates an [`NDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndims = extract_ndims(&ctx.unifier, ndims);
NDArrayType {
ty: Self::llvm_type(ctx.ctx, llvm_usize),
dtype: llvm_dtype,
ndims: Some(ndims),
llvm_usize,
}
}
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize }
}
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.llvm_usize
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype
}
/// Allocates storage for one `NDArray` structure, as if by calling `alloca` on the base type,
/// and wraps the resulting pointer in an [`NDArrayValue`].
///
/// Only the structure itself is allocated; none of its fields are written.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let storage = self.raw_alloca(generator, ctx, name);

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(
        storage,
        self.dtype,
        self.ndims,
        self.llvm_usize,
        name,
    )
}
/// Allocates an ndarray with `ndims` dimensions and this type's `dtype`.
///
/// The `shape` and `strides` arrays are allocated as well, each with `ndims` entries.
///
/// Contents of the returned ndarray:
/// - `data`: uninitialized.
/// - `itemsize`: the `sizeof()` of `dtype`, converted to the `size_t` type.
/// - `ndims`: the value of `ndims`.
/// - `shape`: an array of length `ndims` with uninitialized values.
/// - `strides`: an array of length `ndims` with uninitialized values.
///
/// # Panics
///
/// Panics if `dtype` has no `size_of()`.
#[must_use]
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    ndims: IntValue<'ctx>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let llvm_usize = self.llvm_usize;
    let ndarray = self.alloca(generator, ctx, name);

    let dtype_size = self.dtype.size_of().unwrap();
    let itemsize =
        ctx.builder.build_int_truncate_or_bit_cast(dtype_size, llvm_usize, "").unwrap();

    ndarray.store_itemsize(ctx, generator, itemsize);
    ndarray.store_ndims(ctx, generator, ndims);
    ndarray.create_shape(ctx, llvm_usize, ndims);
    ndarray.create_strides(ctx, llvm_usize, ndims);

    ndarray
}
/// Convenience function. Allocates an [`NDArrayValue`] whose shape is known at compile time.
///
/// `ndims` is set to `shape.len()` and every entry of `shape` is written as a constant into the
/// ndarray's `shape` array. The returned ndarray's `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    shape: &[u64],
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let ndims = generator.get_size_type(ctx.ctx).const_int(shape.len() as u64, false);
    let ndarray = self.construct_uninitialized(generator, ctx, ndims, name);

    // Write shape
    let ndarray_shape = ndarray.shape();
    for (i, &extent) in (0u64..).zip(shape) {
        let dim = self.llvm_usize.const_int(extent, false);
        let idx = self.llvm_usize.const_int(i, false);

        unsafe {
            ndarray_shape.set_typed_unchecked(ctx, generator, &idx, dim);
        }
    }

    ndarray
}
/// Convenience function. Allocates an [`NDArrayValue`] whose shape is given as runtime values.
///
/// `ndims` is set to `shape.len()` and every entry of `shape` is written into the ndarray's
/// `shape` array. The returned ndarray's `data` and `strides` are uninitialized.
///
/// # Panics
///
/// Panics if any entry of `shape` is not of this type's `size_t` type.
#[must_use]
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    shape: &[IntValue<'ctx>],
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let ndims = generator.get_size_type(ctx.ctx).const_int(shape.len() as u64, false);
    let ndarray = self.construct_uninitialized(generator, ctx, ndims, name);

    // Write shape
    let ndarray_shape = ndarray.shape();
    for (i, &dim) in (0u64..).zip(shape) {
        assert_eq!(
            dim.get_type(),
            self.llvm_usize,
            "Expected {} but got {}",
            self.llvm_usize.print_to_string(),
            dim.get_type().print_to_string()
        );

        let idx = self.llvm_usize.const_int(i, false);
        unsafe {
            ndarray_shape.set_typed_unchecked(ctx, generator, &idx, dim);
        }
    }

    ndarray
}
/// Converts an existing value into a [`NDArrayValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
value,
self.dtype,
self.ndims,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
    type Base = PointerType<'ctx>;
    type Value = NDArrayValue<'ctx>;

    /// Succeeds only if `llvm_ty` is a pointer type that passes
    /// [`ProxyType::is_representable`] for this proxy.
    fn is_type<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: impl BasicType<'ctx>,
    ) -> Result<(), String> {
        match llvm_ty.as_basic_type_enum() {
            BasicTypeEnum::PointerType(ptr_ty) => {
                <Self as ProxyType<'ctx>>::is_representable(generator, ctx, ptr_ty)
            }
            _ => Err(format!("Expected pointer type, got {llvm_ty:?}")),
        }
    }

    /// Delegates to the inherent two-argument `is_representable`, using the generator's size
    /// type as `llvm_usize`.
    fn is_representable<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: Self::Base,
    ) -> Result<(), String> {
        let llvm_usize = generator.get_size_type(ctx);

        Self::is_representable(llvm_ty, llvm_usize)
    }

    /// Allocates one instance of the pointed-to struct type through
    /// `CodeGenerator::gen_var_alloc` and returns the raw pointer.
    fn raw_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        name: Option<&'ctx str>,
    ) -> <Self::Value as ProxyValue<'ctx>>::Base {
        let struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_var_alloc(ctx, struct_ty.into(), name).unwrap()
    }

    /// Allocates `size` instances of the pointed-to struct type through
    /// `CodeGenerator::gen_array_var_alloc`.
    fn array_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        size: IntValue<'ctx>,
        name: Option<&'ctx str>,
    ) -> ArraySliceValue<'ctx> {
        let struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_array_var_alloc(ctx, struct_ty.into(), size, name).unwrap()
    }

    /// Returns the underlying LLVM pointer type.
    fn as_base_type(&self) -> Self::Base {
        self.ty
    }
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,256 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::ProxyType;
use crate::codegen::{
irrt,
types::structure::{StructField, StructFields},
values::{nditer::NDIterValue, ArraySliceValue, NDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Proxy type for an `NDIter` (ndarray iterator) type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDIterType<'ctx> {
/// Pointer type to the LLVM structure representing the `NDIter`.
ty: PointerType<'ctx>,
/// Integer type used for the `size_t`-typed fields of the structure.
llvm_usize: IntType<'ctx>,
}
/// Field accessors for the LLVM structure backing an `NDIter`.
///
/// NOTE(review): the meaning of each field is defined by the IRRT implementation that
/// `call_nac3_nditer_initialize` invokes, which is not visible here -- confirm against it before
/// relying on any field. The `StructFields` derive presumably emits the fields in declaration
/// order, so do not reorder them without checking.
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDIterStructFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub strides: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to a `size_t` array; `NDIterType::construct` allocates it with `ndims` entries.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub indices: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
pub nth: StructField<'ctx, IntValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub element: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
pub size: StructField<'ctx, IntValue<'ctx>>,
}
impl<'ctx> NDIterType<'ctx> {
/// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not.
///
/// `llvm_ty` must point to a struct type whose number of fields and per-field types match the
/// fields produced by [`NDIterType::fields`] for the given `llvm_usize`.
///
/// # Errors
///
/// Returns a message describing the first mismatch found: a non-struct pointee, a wrong field
/// count, or the name of the first field whose type differs from the expected one.
pub fn is_representable(
    llvm_ty: PointerType<'ctx>,
    llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
    let ctx = llvm_ty.get_context();

    let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();

    let llvm_nditer_ty = llvm_ty.get_element_type();
    let AnyTypeEnum::StructType(llvm_nditer_ty) = llvm_nditer_ty else {
        return Err(format!("Expected struct type for `NDIter` type, got {llvm_nditer_ty}"));
    };
    if llvm_nditer_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
        return Err(format!(
            "Expected {} fields in `NDIter`, got {}",
            llvm_expected_ty.len(),
            llvm_nditer_ty.count_fields()
        ));
    }

    // Pair every expected field with a `u32` index so that no truncating cast is needed.
    for (i, expected) in (0u32..).zip(&llvm_expected_ty) {
        let (field_name, expected_ty) = (expected.0, expected.1);
        // The field count was validated above, so every index in range has a type.
        let actual_ty = llvm_nditer_ty.get_field_type_at_index(i).unwrap();

        if expected_ty != actual_ty {
            // Report the field that actually mismatched rather than a hard-coded one.
            return Err(format!(
                "Expected {expected_ty} for `nditer.{field_name}`, got {actual_ty}"
            ));
        }
    }

    Ok(())
}
/// Builds the [`StructFields`] instance holding every field accessor of the `NDIter`
/// structure, using `llvm_usize` as the `size_t` type.
#[must_use]
fn fields(
    ctx: impl AsContextRef<'ctx>,
    llvm_usize: IntType<'ctx>,
) -> NDIterStructFields<'ctx> {
    NDIterStructFields::new(ctx, llvm_usize)
}
/// Instance-level counterpart of [`NDIterType::fields`] that supplies this type's own
/// `llvm_usize`.
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> {
    let llvm_usize = self.llvm_usize;

    Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDIter`.
///
/// The result is a pointer (default address space) to a non-packed struct whose members are the
/// types yielded by [`NDIterType::fields`], in iteration order.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
    let member_tys = Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
    let struct_ty = ctx.struct_type(&member_tys, false);

    struct_ty.ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDIterType`].
///
/// The `size_t` type is taken from `generator`, and the LLVM structure type is built from it.
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
    let llvm_usize = generator.get_size_type(ctx);

    Self { ty: Self::llvm_type(ctx, llvm_usize), llvm_usize }
}
/// Wraps an existing [`PointerType`] representing an `NDIter` in an [`NDIterType`].
///
/// In debug builds, asserts that `ptr_ty` passes [`NDIterType::is_representable`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
    let is_valid = || Self::is_representable(ptr_ty, llvm_usize).is_ok();
    debug_assert!(is_valid());

    Self { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `nditer` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.llvm_usize
}
/// Allocates storage for one `NDIter` structure and wraps the resulting pointer in an
/// [`NDIterValue`] associated with `parent` and `indices`.
///
/// Only the structure itself is allocated; none of its fields are written.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    parent: NDArrayValue<'ctx>,
    indices: ArraySliceValue<'ctx>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let storage = self.raw_alloca(generator, ctx, name);

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(
        storage,
        parent,
        indices,
        self.llvm_usize,
        name,
    )
}
/// Allocates an `NDIter` that iterates through the given `ndarray` and initializes it by
/// calling `irrt::ndarray::call_nac3_nditer_initialize`.
///
/// # Panics
///
/// Panics if allocating the iterator or its `indices` array fails.
#[must_use]
pub fn construct<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
) -> <Self as ProxyType<'ctx>>::Value {
    let nditer_ptr = self.raw_alloca(generator, ctx, None);
    let ndims = ndarray.load_ndims(ctx);

    // The initializer does not allocate `indices` itself: the caller has to hand it an array
    // with one `size_t` entry per dimension.
    let indices =
        generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap();

    let nditer = self.map_value(nditer_ptr, ndarray, indices, None);
    irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices);

    nditer
}
/// Wraps an existing pointer `value` in an [`NDIterValue`] associated with `parent` and
/// `indices`, carrying this type's `size_t` type.
#[must_use]
pub fn map_value(
    &self,
    value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
    parent: NDArrayValue<'ctx>,
    indices: ArraySliceValue<'ctx>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let llvm_usize = self.llvm_usize;

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(value, parent, indices, llvm_usize, name)
}
}
impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> {
    type Base = PointerType<'ctx>;
    type Value = NDIterValue<'ctx>;

    /// Succeeds only if `llvm_ty` is a pointer type that passes
    /// [`ProxyType::is_representable`] for this proxy.
    fn is_type<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: impl BasicType<'ctx>,
    ) -> Result<(), String> {
        match llvm_ty.as_basic_type_enum() {
            BasicTypeEnum::PointerType(ptr_ty) => {
                <Self as ProxyType<'ctx>>::is_representable(generator, ctx, ptr_ty)
            }
            _ => Err(format!("Expected pointer type, got {llvm_ty:?}")),
        }
    }

    /// Delegates to the inherent two-argument `is_representable`, using the generator's size
    /// type as `llvm_usize`.
    fn is_representable<G: CodeGenerator + ?Sized>(
        generator: &G,
        ctx: &'ctx Context,
        llvm_ty: Self::Base,
    ) -> Result<(), String> {
        let llvm_usize = generator.get_size_type(ctx);

        Self::is_representable(llvm_ty, llvm_usize)
    }

    /// Allocates one instance of the pointed-to struct type through
    /// `CodeGenerator::gen_var_alloc` and returns the raw pointer.
    fn raw_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        name: Option<&'ctx str>,
    ) -> <Self::Value as ProxyValue<'ctx>>::Base {
        let struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_var_alloc(ctx, struct_ty.into(), name).unwrap()
    }

    /// Allocates `size` instances of the pointed-to struct type through
    /// `CodeGenerator::gen_array_var_alloc`.
    fn array_alloca<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        size: IntValue<'ctx>,
        name: Option<&'ctx str>,
    ) -> ArraySliceValue<'ctx> {
        let struct_ty = self.as_base_type().get_element_type().into_struct_type();

        generator.gen_array_var_alloc(ctx, struct_ty.into(), size, name).unwrap()
    }

    /// Returns the underlying LLVM pointer type.
    fn as_base_type(&self) -> Self::Base {
        self.ty
    }
}
impl<'ctx> From<NDIterType<'ctx>> for PointerType<'ctx> {
fn from(value: NDIterType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -76,6 +76,30 @@ impl<'ctx> RangeType<'ctx> {
pub fn value_type(&self) -> IntType<'ctx> { pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
} }
/// Allocates storage for one range structure, as if by calling `alloca` on the base type, and
/// wraps the resulting pointer in a [`RangeValue`].
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let storage = self.raw_alloca(generator, ctx, name);

    <Self as ProxyType<'ctx>>::Value::from_pointer_value(storage, name)
}
/// Wraps an existing pointer `value` in a [`RangeValue`] labelled with `name`.
#[must_use]
pub fn map_value(
    &self,
    value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
    name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
    let wrapped = <Self as ProxyType<'ctx>>::Value::from_pointer_value(value, name);
    wrapped
}
} }
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
@ -102,25 +126,22 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
Self::is_representable(llvm_ty) Self::is_representable(llvm_ty)
} }
fn new_value<G: CodeGenerator + ?Sized>( fn raw_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self::Value { ) -> <Self::Value as ProxyValue<'ctx>>::Base {
self.map_value(
generator generator
.gen_var_alloc( .gen_var_alloc(
ctx, ctx,
self.as_base_type().get_element_type().into_struct_type().into(), self.as_base_type().get_element_type().into_struct_type().into(),
name, name,
) )
.unwrap(), .unwrap()
name,
)
} }
fn new_array_value<G: CodeGenerator + ?Sized>( fn array_alloca<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -137,16 +158,6 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
.unwrap() .unwrap()
} }
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue::from_pointer_value(value, name)
}
fn as_base_type(&self) -> Self::Base { fn as_base_type(&self) -> Self::Base {
self.ty self.ty
} }

View File

@ -103,6 +103,12 @@ where
StructField { index, name, ty: ty.into(), _value_ty: PhantomData } StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
} }
/// Returns the name of this field.
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32 /// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`. /// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep( pub fn ptr_by_array_gep(
@ -145,7 +151,7 @@ where
} }
/// Sets the value of this field for a given `obj`. /// Sets the value of this field for a given `obj`.
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) { pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value); obj.set_field_at_index(self.index, value);
} }

View File

@ -207,7 +207,7 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. /// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
type ValueDowncastFn<'ctx, T> = type ValueDowncastFn<'ctx, T> =
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>; Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T + 'ctx>;
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. /// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>; type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;

View File

@ -0,0 +1,206 @@
use inkwell::{
types::{BasicType, BasicTypeEnum, IntType},
values::{IntValue, PointerValue},
AddressSpace,
};
use super::{ArrayLikeValue, NDArrayValue, ProxyValue};
use crate::codegen::{
stmt::gen_if_callback,
types::{structure::StructField, ContiguousNDArrayType, NDArrayType},
CodeGenContext, CodeGenerator,
};
/// Proxy type for accessing a `ContiguousNDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct ContiguousNDArrayValue<'ctx> {
/// Pointer to the `ContiguousNDArray` structure.
value: PointerValue<'ctx>,
/// LLVM type of the items; the `data` pointer is cast to a pointer to this type before being
/// stored by `make_contiguous_ndarray`.
item: BasicTypeEnum<'ctx>,
/// Integer type used as `size_t`.
llvm_usize: IntType<'ctx>,
/// Optional name passed along to the field accessors.
name: Option<&'ctx str>,
}
impl<'ctx> ContiguousNDArrayValue<'ctx> {
    /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is
    /// not an instance.
    pub fn is_representable(
        value: PointerValue<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Result<(), String> {
        let ptr_ty = value.get_type();

        <Self as ProxyValue<'ctx>>::Type::is_representable(ptr_ty, llvm_usize)
    }

    /// Creates a [`ContiguousNDArrayValue`] from a [`PointerValue`].
    ///
    /// In debug builds, asserts that `ptr` passes [`ContiguousNDArrayValue::is_representable`].
    #[must_use]
    pub fn from_pointer_value(
        ptr: PointerValue<'ctx>,
        dtype: BasicTypeEnum<'ctx>,
        llvm_usize: IntType<'ctx>,
        name: Option<&'ctx str>,
    ) -> Self {
        debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());

        Self { value: ptr, item: dtype, llvm_usize, name }
    }

    /// Returns the accessor for the `ndims` field.
    fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
        let fields = self.get_type().get_fields();
        fields.ndims
    }

    /// Stores `value` into the `ndims` field of this instance.
    pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
        let this = self.as_base_value();
        self.ndims_field().set(ctx, this, value, self.name);
    }

    /// Returns the accessor for the `shape` field.
    fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
        let fields = self.get_type().get_fields();
        fields.shape
    }

    /// Stores the pointer `value` into the `shape` field of this instance.
    pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
        let this = self.as_base_value();
        self.shape_field().set(ctx, this, value, self.name);
    }

    /// Loads the pointer held in the `shape` field of this instance.
    pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
        let field = self.shape_field();
        field.get(ctx, self.value, self.name)
    }

    /// Returns the accessor for the `data` field.
    fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
        let fields = self.get_type().get_fields();
        fields.data
    }

    /// Stores the pointer `value` into the `data` field of this instance.
    pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
        let this = self.as_base_value();
        self.data_field().set(ctx, this, value, self.name);
    }

    /// Loads the pointer held in the `data` field of this instance.
    pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
        let field = self.data_field();
        field.get(ctx, self.value, self.name)
    }
}
impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {
    type Base = PointerValue<'ctx>;
    type Type = ContiguousNDArrayType<'ctx>;

    /// Reconstructs the proxy type from the pointer's LLVM type, the item type and the `size_t`
    /// type held by this value.
    fn get_type(&self) -> Self::Type {
        let ptr_ty = self.as_base_value().get_type();

        <Self as ProxyValue<'ctx>>::Type::from_type(ptr_ty, self.item, self.llvm_usize)
    }

    /// Returns the underlying pointer to the `ContiguousNDArray` structure.
    fn as_base_value(&self) -> Self::Base {
        self.value
    }
}
impl<'ctx> From<ContiguousNDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ContiguousNDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
impl<'ctx> NDArrayValue<'ctx> {
/// Create a [`ContiguousNDArrayValue`] from the contents of this ndarray.
///
/// The result's `ndims` is the statically known dimension count when available, otherwise the
/// value loaded from this ndarray. The result's `shape` pointer is this ndarray's shape array
/// (shared, not copied).
///
/// This function may or may not be expensive depending on if this ndarray has contiguous data:
///
/// - If this ndarray is C-contiguous, contents of this ndarray will not be copied. The result's
///   `data` points at this ndarray's data, so both share memory.
/// - If this ndarray is not C-contiguous, a copy is produced with `make_copy` and the result's
///   `data` points at the copy's data.
///   NOTE(review): where `make_copy` places the copied data is not visible here -- confirm
///   before relying on its lifetime.
///
/// In both cases the `data` pointer is cast to a pointer to the result's item type before being
/// stored.
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ContiguousNDArrayValue<'ctx> {
let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype)
.alloca(generator, ctx, self.name);
// Set ndims and shape.
let ndims = self
.ndims
.map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false));
result.store_ndims(ctx, ndims);
let shape = self.shape();
result.store_shape(ctx, shape.base_ptr(ctx, generator));
// Emit a runtime branch on contiguity; each branch stores its own `data` pointer.
gen_if_callback(
generator,
ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|_, ctx| {
// This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
|generator, ctx| {
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
// ndarray with contiguous `data`.
let copied_ndarray = self.make_copy(generator, ctx);
let data = copied_ndarray.data().base_ptr(ctx, generator);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
)
.unwrap();
result
}
/// Create an [`NDArrayValue`] from a [`ContiguousNDArrayValue`].
///
/// No element data is copied: the new ndarray's `data` is the `carray`'s data pointer (cast to
/// `i8*`), so both share the same memory. The shape is copied from `carray` and the strides are
/// set for a contiguous layout.
///
/// `ndims` has to be provided by the caller as a compile-time value, despite the fact that the
/// information should be contained within the `carray`.
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    carray: ContiguousNDArrayValue<'ctx>,
    ndims: u64,
) -> Self {
    // TODO: Debug assert `ndims == carray.ndims` to catch bugs.

    // Allocate the resulting ndarray.
    let ndarray_ty = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims));
    let llvm_ndims = carray.llvm_usize.const_int(ndims, false);
    let ndarray = ndarray_ty.construct_uninitialized(generator, ctx, llvm_ndims, carray.name);

    // Copy shape and update strides
    let shape = carray.load_shape(ctx);
    ndarray.copy_shape_from_array(generator, ctx, shape);
    ndarray.set_strides_contiguous(generator, ctx);

    // Share data
    let data = carray.load_data(ctx);
    let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
    let data = ctx.builder.build_pointer_cast(data, llvm_pi8, "").unwrap();
    ndarray.store_data(ctx, data);

    ndarray
}
}

View File

@ -9,18 +9,25 @@ use super::{
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}; };
use crate::codegen::{ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt,
llvm_intrinsics::call_int_umin, llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::NDArrayType, type_aligned_alloca,
types::{structure::StructField, NDArrayType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
pub use contiguous::*;
mod contiguous;
pub mod nditer;
/// Proxy type for accessing an `NDArray` value in LLVM. /// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> { pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>, value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
} }
@ -40,20 +47,22 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn from_pointer_value( pub fn from_pointer_value(
ptr: PointerValue<'ctx>, ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self { ) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, llvm_usize, name } NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
}
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).ndims
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the number of dimensions `ndims` into this instance. /// Stores the number of dimensions `ndims` into this instance.
@ -75,18 +84,40 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
} }
/// Returns the accessor for the `itemsize` field of this `NDArray`.
fn itemsize_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
    let fields = self.get_type().get_fields(ctx.ctx);
    fields.itemsize
}
/// Writes `itemsize`, the size of each element, into the `itemsize` field of this instance.
///
/// In debug builds, asserts that `itemsize` has the generator's size type.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
    &self,
    ctx: &CodeGenContext<'ctx, '_>,
    generator: &G,
    itemsize: IntValue<'ctx>,
) {
    debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));

    let field = self.itemsize_field(ctx);
    field.set(ctx, self.value, itemsize, self.name);
}
/// Loads the value of the `itemsize` field, i.e. the size of each element of this `NDArray`.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
    let field = self.itemsize_field(ctx);
    field.get(ctx, self.value, self.name)
}
/// Returns the accessor for the `shape` field of this `NDArray`.
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
    let fields = self.get_type().get_fields(ctx.ctx);
    fields.shape
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling /// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field. /// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the array of dimension sizes `dims` into this instance. /// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
} }
/// Convenience method for creating a new array storing dimension sizes with the given `size`. /// Convenience method for creating a new array storing dimension sizes with the given `size`.
@ -105,13 +136,48 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayShapeProxy(self) NDArrayShapeProxy(self)
} }
/// Returns the [`StructField`] descriptor of the `strides` member of this ndarray's type.
fn strides_field(
    &self,
    ctx: &CodeGenContext<'ctx, '_>,
) -> StructField<'ctx, PointerValue<'ctx>> {
    let ndarray_ty = self.get_type();
    ndarray_ty.get_fields(ctx.ctx).strides
}
/// Returns a pointer to the `strides` field itself (i.e. a pointer to the pointer to the
/// strides array), as if by calling `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
    let field = self.strides_field(ctx);
    field.ptr_by_gep(ctx, self.value, self.name)
}
/// Writes the pointer `strides` (the array of per-dimension strides) into the `strides` field
/// of this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) {
    let field = self.strides_field(ctx);
    field.set(ctx, self.as_base_value(), strides, self.name);
}
/// Convenience method that allocates a fresh strides array and installs it in this instance.
///
/// An array of `size` elements of type `llvm_usize` is created with `build_array_alloca`, and
/// the resulting pointer is stored into the `strides` field.
pub fn create_strides(
    &self,
    ctx: &CodeGenContext<'ctx, '_>,
    llvm_usize: IntType<'ctx>,
    size: IntValue<'ctx>,
) {
    let strides = ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap();
    self.store_strides(ctx, strides);
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
///
/// The returned [`NDArrayStridesProxy`] only wraps a borrow of `self`; no IR is emitted by
/// this call.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
    NDArrayStridesProxy(self)
}
/// Returns the [`StructField`] descriptor of the `data` member of this ndarray's type.
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
    let ndarray_ty = self.get_type();
    ndarray_ty.get_fields(ctx.ctx).data
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
@ -120,26 +186,28 @@ impl<'ctx> NDArrayValue<'ctx> {
.builder .builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(); .unwrap();
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
} }
/// Convenience method for creating a new array storing data elements with the given element /// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`. /// type `elem_ty` and `size`.
pub fn create_data( ///
/// The data buffer will be allocated on the stack, and is considered to be owned by this ndarray instance.
///
/// # Safety
///
/// The caller must ensure that `shape` and `itemsize` of this ndarray instance is initialized.
pub unsafe fn create_data<G: CodeGenerator + ?Sized>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, generator: &mut G,
elem_ty: BasicTypeEnum<'ctx>, ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
) { ) {
let itemsize = let nbytes = self.nbytes(generator, ctx);
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
// TODO: What about alignment? let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
self.store_data( self.store_data(ctx, data);
ctx,
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(), self.set_strides_contiguous(generator, ctx);
);
} }
/// Returns a proxy object to the field storing the data of this `NDArray`. /// Returns a proxy object to the field storing the data of this `NDArray`.
@ -147,6 +215,196 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self) NDArrayDataProxy(self)
} }
/// Copies shape dimensions from the array pointed to by `shape` into this ndarray's shape.
///
/// The number of copied elements is the `ndims` value loaded from this instance.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    shape: PointerValue<'ctx>,
) {
    let ndims = self.load_ndims(ctx);
    let dst_shape = self.shape().base_ptr(ctx, generator);
    let i1_false = ctx.ctx.bool_type().const_zero();

    call_memcpy_generic_array(ctx, dst_shape, shape, ndims, i1_false);
}
/// Copies the shape of `src_ndarray` into this ndarray.
///
/// If both ndarrays carry a compile-time `ndims`, the two values are compared right away.
/// Otherwise both `ndims` are loaded and an `AssertionError` check is emitted through
/// `make_assert`.
///
/// # Panics
///
/// Panics if both `ndims` are known at compile time and differ.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    src_ndarray: NDArrayValue<'ctx>,
) {
    match (self.ndims, src_ndarray.ndims) {
        // Both ranks are statically known: check them while compiling.
        (Some(_), Some(_)) => assert_eq!(self.ndims, src_ndarray.ndims),

        // At least one rank is only known at runtime: emit the check into the program.
        _ => {
            let self_ndims = self.load_ndims(ctx);
            let src_ndims = src_ndarray.load_ndims(ctx);
            let ndims_match = ctx
                .builder
                .build_int_compare(IntPredicate::EQ, self_ndims, src_ndims, "")
                .unwrap();

            ctx.make_assert(
                generator,
                ndims_match,
                "0:AssertionError",
                "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
                [Some(self_ndims), Some(src_ndims), None],
                ctx.current_loc,
            );
        }
    }

    let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
    self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copies strides from the array pointed to by `strides` into this ndarray's strides.
///
/// The number of copied elements is the `ndims` value loaded from this instance.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    strides: PointerValue<'ctx>,
) {
    let ndims = self.load_ndims(ctx);
    let dst_strides = self.strides().base_ptr(ctx, generator);
    let i1_false = ctx.ctx.bool_type().const_zero();

    call_memcpy_generic_array(ctx, dst_strides, strides, ndims, i1_false);
}
/// Copy strides dimensions from an ndarray.
///
/// If both ndarrays carry a compile-time `ndims`, the two values are compared right away.
/// Otherwise both `ndims` are loaded and an `AssertionError` check is emitted through
/// `make_assert`.
///
/// # Panics
///
/// Panics if both `ndims` are known at compile time and differ.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    src_ndarray: NDArrayValue<'ctx>,
) {
    if self.ndims.is_some() && src_ndarray.ndims.is_some() {
        assert_eq!(self.ndims, src_ndarray.ndims);
    } else {
        let self_ndims = self.load_ndims(ctx);
        let src_ndims = src_ndarray.load_ndims(ctx);

        ctx.make_assert(
            generator,
            ctx.builder.build_int_compare(
                IntPredicate::EQ,
                self_ndims,
                src_ndims,
                ""
            ).unwrap(),
            "0:AssertionError",
            // The message names this function so that a failure is attributed to the strides
            // copy rather than to the shape copy.
            "NDArrayValue::copy_strides_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
            [Some(self_ndims), Some(src_ndims), None],
            ctx.current_loc
        );
    }

    let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
    self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Get the `np.size()` of this ndarray.
///
/// Delegates to [`irrt::ndarray::call_nac3_ndarray_size`], passing a copy of this value.
pub fn size<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
    irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
}
/// Get the `ndarray.nbytes` of this ndarray.
///
/// Delegates to [`irrt::ndarray::call_nac3_ndarray_nbytes`], passing a copy of this value.
pub fn nbytes<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
    irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
}
/// Get the `len()` of this ndarray.
///
/// Delegates to [`irrt::ndarray::call_nac3_ndarray_len`], passing a copy of this value.
pub fn len<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
    irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
///
/// Delegates to [`irrt::ndarray::call_nac3_ndarray_is_c_contiguous`] and returns its result
/// unchanged.
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
    irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
}
/// Call [`irrt::ndarray::call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update
/// `strides`.
///
/// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
) {
    irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
}
/// Creates a new ndarray of the same type as this one and fills it with this ndarray's data.
///
/// The new ndarray is built with `construct_uninitialized`, receives a copy of this
/// ndarray's shape, gets its own data buffer through `create_data`, and finally has the
/// elements of `self` copied into it via `copy_data_from`.
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
    // Prefer the compile-time rank when there is one; otherwise read it from the instance.
    let ndims = match self.ndims {
        Some(ndims) => self.llvm_usize.const_int(ndims, false),
        None => self.load_ndims(ctx),
    };
    let clone = self.get_type().construct_uninitialized(generator, ctx, ndims, None);

    let src_shape = self.shape().base_ptr(ctx, generator);
    clone.copy_shape_from_array(generator, ctx, src_shape);

    // SAFETY: `shape` was copied into `clone` just above. NOTE(review): `itemsize` is assumed
    // to have been set by `construct_uninitialized` - confirm.
    unsafe { clone.create_data(generator, ctx) };
    clone.copy_data_from(generator, ctx, *self);

    clone
}
/// Copy data from another ndarray.
///
/// The only requirement on this ndarray and `src` is that their `np.size()` should be the
/// same. Their shapes do not matter. The copying order is determined by how their flattened
/// views look.
///
/// The copy itself is performed by [`irrt::ndarray::call_nac3_ndarray_copy_data`], with `src`
/// passed first and this ndarray second.
///
/// # Panics
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
    &self,
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    src: NDArrayValue<'ctx>,
) {
    assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
    irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
}
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
@ -154,7 +412,12 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Type = NDArrayType<'ctx>; type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type { fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize) NDArrayType::from_type(
self.as_base_value().get_type(),
self.dtype,
self.ndims,
self.llvm_usize,
)
} }
fn as_base_value(&self) -> Self::Base { fn as_base_value(&self) -> Self::Base {
@ -265,6 +528,103 @@ impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ct
} }
} }
/// Proxy type for accessing the `strides` array of an `NDArray` instance in LLVM.
///
/// Wraps a borrow of the [`NDArrayValue`] whose `strides` field is accessed.
#[derive(Copy, Clone)]
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
    /// Returns the pointee type of the loaded strides pointer.
    fn element_type<G: CodeGenerator + ?Sized>(
        &self,
        ctx: &CodeGenContext<'ctx, '_>,
        generator: &G,
    ) -> AnyTypeEnum<'ctx> {
        let strides_ptr = self.0.strides().base_ptr(ctx, generator);
        strides_ptr.get_type().get_element_type()
    }

    /// Loads the pointer stored in the `strides` field of the wrapped ndarray.
    ///
    /// The loaded value is named `<name>.strides` when the ndarray has a name, and is left
    /// unnamed otherwise.
    fn base_ptr<G: CodeGenerator + ?Sized>(
        &self,
        ctx: &CodeGenContext<'ctx, '_>,
        _: &G,
    ) -> PointerValue<'ctx> {
        let var_name = match self.0.name {
            Some(v) => format!("{v}.strides"),
            None => String::new(),
        };

        let field_ptr = self.0.ptr_to_strides(ctx);
        ctx.builder.build_load(field_ptr, var_name.as_str()).unwrap().into_pointer_value()
    }

    /// Returns the number of stride entries, which is the `ndims` of the wrapped ndarray.
    fn size<G: CodeGenerator + ?Sized>(
        &self,
        ctx: &CodeGenContext<'ctx, '_>,
        _: &G,
    ) -> IntValue<'ctx> {
        self.0.load_ndims(ctx)
    }
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
    /// Computes the address of the `idx`-th stride entry without any bounds check.
    ///
    /// The resulting value is named `<name>.addr` when `name` is given.
    unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
        &self,
        ctx: &mut CodeGenContext<'ctx, '_>,
        generator: &mut G,
        idx: &IntValue<'ctx>,
        name: Option<&str>,
    ) -> PointerValue<'ctx> {
        let var_name = match name {
            Some(v) => format!("{v}.addr"),
            None => String::new(),
        };
        let strides_ptr = self.base_ptr(ctx, generator);

        // SAFETY: the caller of this `unsafe fn` is responsible for `idx` being in bounds.
        unsafe {
            ctx.builder.build_in_bounds_gep(strides_ptr, &[*idx], var_name.as_str()).unwrap()
        }
    }

    /// Computes the address of the `idx`-th stride entry, emitting an `IndexError` assertion
    /// that `idx` is (unsigned) less than the number of entries.
    fn ptr_offset<G: CodeGenerator + ?Sized>(
        &self,
        ctx: &mut CodeGenContext<'ctx, '_>,
        generator: &mut G,
        idx: &IntValue<'ctx>,
        name: Option<&str>,
    ) -> PointerValue<'ctx> {
        let num_entries = self.size(ctx, generator);
        let is_in_bounds =
            ctx.builder.build_int_compare(IntPredicate::ULT, *idx, num_entries, "").unwrap();

        // `ndims` is loaded a second time for the error message parameters.
        let reported_size = self.0.load_ndims(ctx);
        ctx.make_assert(
            generator,
            is_in_bounds,
            "0:IndexError",
            "index {0} is out of bounds for axis 0 with size {1}",
            [Some(*idx), Some(reported_size), None],
            ctx.current_loc,
        );

        // SAFETY: the assertion emitted above rejects out-of-range `idx` values.
        unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
    }
}
// Both impls are empty: the strides proxy relies entirely on the items provided by the traits.
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
    /// Converts a loaded stride entry into an [`IntValue`].
    ///
    /// Uses `into_int_value`, so `value` is expected to already be an integer value.
    fn downcast_to_type(
        &self,
        _: &mut CodeGenContext<'ctx, '_>,
        value: BasicValueEnum<'ctx>,
    ) -> IntValue<'ctx> {
        value.into_int_value()
    }
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
    /// Wraps a stride [`IntValue`] into a [`BasicValueEnum`]; the context is not used.
    fn upcast_from_type(
        &self,
        _: &mut CodeGenContext<'ctx, '_>,
        value: IntValue<'ctx>,
    ) -> BasicValueEnum<'ctx> {
        value.into()
    }
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. /// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
@ -283,12 +643,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
@ -296,7 +651,12 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
generator: &G, generator: &G,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) irrt::ndarray::call_ndarray_calc_size(
generator,
ctx,
&self.as_slice_value(ctx, generator),
(None, None),
)
} }
} }
@ -405,7 +765,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices_elem_ty.get_bit_width() indices_elem_ty.get_bit_width()
); );
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices);
let sizeof_elem = ctx let sizeof_elem = ctx
.builder .builder
.build_int_truncate_or_bit_cast( .build_int_truncate_or_bit_cast(
@ -521,3 +881,18 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
for NDArrayDataProxy<'ctx, '_> for NDArrayDataProxy<'ctx, '_>
{ {
} }
/// A version of [`irrt::ndarray::call_nac3_ndarray_set_strides_by_shape`] in Rust.
///
/// This function is used for generating strides for globally defined contiguous ndarrays.
///
/// The stride of the last axis is `itemsize`; the stride of every other axis is the stride of
/// the axis after it multiplied by that axis' extent in `shape` (a row-major layout). Only the
/// first `ndims` entries of `shape` are read.
///
/// # Panics
///
/// Panics if `shape` has fewer than `ndims` entries.
#[must_use]
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
    let ndims = ndims as usize;

    // The vector has to actually contain `ndims` elements before it can be indexed:
    // `Vec::with_capacity` only reserves memory and leaves the length at zero, so writing
    // through `strides[axis]` on such a vector panics.
    let mut strides = vec![0u64; ndims];

    let mut stride_product = 1u64;
    for axis in (0..ndims).rev() {
        strides[axis] = stride_product * itemsize;
        stride_product *= shape[axis];
    }

    strides
}

View File

@ -0,0 +1,176 @@
use inkwell::{
types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator};
use crate::codegen::{
irrt,
stmt::{gen_for_callback, BreakContinueHooks},
types::{nditer::NDIterType, structure::StructField},
values::{ArraySliceValue, TypedArrayLikeAdapter},
CodeGenContext, CodeGenerator,
};
/// Proxy value wrapping a pointer to an `NDIter` instance (see [`NDIterType`]), together with
/// the ndarray it was created for.
#[derive(Copy, Clone)]
pub struct NDIterValue<'ctx> {
    /// Pointer to the iterator instance; this is the base value of the proxy.
    value: PointerValue<'ctx>,
    /// The ndarray associated with this iterator; its `dtype` is used to type element pointers.
    parent: NDArrayValue<'ctx>,
    /// Array slice exposed through `get_indices`.
    indices: ArraySliceValue<'ctx>,
    /// The size type used when rebuilding the proxy type and when widening index values.
    llvm_usize: IntType<'ctx>,
    /// Optional name given at construction.
    name: Option<&'ctx str>,
}
impl<'ctx> NDIterValue<'ctx> {
    /// Checks whether `value` is an instance of `NDIter`, returning [Err] if `value` is not an
    /// instance.
    ///
    /// The check is delegated to `is_representable` of the associated proxy type
    /// ([`NDIterType`]), applied to the type of `value`.
    pub fn is_representable(
        value: PointerValue<'ctx>,
        llvm_usize: IntType<'ctx>,
    ) -> Result<(), String> {
        <Self as ProxyValue>::Type::is_representable(value.get_type(), llvm_usize)
    }

    /// Creates an [`NDIterValue`] from a [`PointerValue`].
    ///
    /// In debug builds, asserts that `ptr` passes [`Self::is_representable`].
    #[must_use]
    pub fn from_pointer_value(
        ptr: PointerValue<'ctx>,
        parent: NDArrayValue<'ctx>,
        indices: ArraySliceValue<'ctx>,
        llvm_usize: IntType<'ctx>,
        name: Option<&'ctx str>,
    ) -> Self {
        debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());

        Self { value: ptr, parent, indices, llvm_usize, name }
    }

    /// Is the current iteration valid?
    ///
    /// If true, then `element`, `indices` and `nth` contain details about the current element.
    ///
    /// If `ndarray` is unsized, this returns true only for the first iteration.
    /// If `ndarray` is 0-sized, this always returns false.
    #[must_use]
    pub fn has_element<G: CodeGenerator + ?Sized>(
        &self,
        generator: &G,
        ctx: &CodeGenContext<'ctx, '_>,
    ) -> IntValue<'ctx> {
        irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self)
    }

    /// Go to the next element. If `has_element()` is false, then this has undefined behavior.
    ///
    /// If `ndarray` is unsized, this can only be called once.
    /// If `ndarray` is 0-sized, this can never be called.
    pub fn next<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) {
        irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self);
    }

    /// Returns the [`StructField`] descriptor of the `element` member of the iterator type.
    fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
        self.get_type().get_fields(ctx.ctx).element
    }

    /// Get pointer to the current element.
    ///
    /// The pointer read from the `element` field is cast to a pointer to the parent ndarray's
    /// `dtype`.
    #[must_use]
    pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
        let elem_ty = self.parent.dtype;

        let p = self.element(ctx).get(ctx, self.as_base_value(), None);
        ctx.builder
            .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element")
            .unwrap()
    }

    /// Get the value of the current element.
    ///
    /// This is a load through the pointer returned by [`Self::get_pointer`].
    #[must_use]
    pub fn get_scalar(&self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
        let p = self.get_pointer(ctx);
        ctx.builder.build_load(p, "value").unwrap()
    }

    /// Returns the [`StructField`] descriptor of the `nth` member of the iterator type.
    fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
        self.get_type().get_fields(ctx.ctx).nth
    }

    /// Get the index of the current element if this ndarray were a flat ndarray.
    ///
    /// This is the value read from the `nth` field.
    #[must_use]
    pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
        self.nth(ctx).get(ctx, self.as_base_value(), None)
    }

    /// Get the indices of the current element.
    ///
    /// Returns an adapter over the `indices` slice of this iterator. Values read through the
    /// adapter are zero-extended or bit-cast to `llvm_usize`; values written through it are
    /// passed on as they are.
    #[must_use]
    pub fn get_indices(
        &'ctx self,
    ) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>>
    {
        TypedArrayLikeAdapter::from(
            self.indices,
            Box::new(|ctx, val| {
                ctx.builder
                    .build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "")
                    .unwrap()
            }),
            Box::new(|_, val| val.into()),
        )
    }
}
impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> {
    type Base = PointerValue<'ctx>;
    type Type = NDIterType<'ctx>;

    /// Rebuilds the [`NDIterType`] proxy from the type of the wrapped pointer and `llvm_usize`.
    fn get_type(&self) -> Self::Type {
        NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize)
    }

    /// Returns the wrapped pointer to the iterator instance.
    fn as_base_value(&self) -> Self::Base {
        self.value
    }
}
/// Unwraps an [`NDIterValue`] into the pointer it wraps.
impl<'ctx> From<NDIterValue<'ctx>> for PointerValue<'ctx> {
    fn from(value: NDIterValue<'ctx>) -> Self {
        value.as_base_value()
    }
}
impl<'ctx> NDArrayValue<'ctx> {
    /// Iterate through every element in the ndarray.
    ///
    /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to
    /// get properties of the current iteration (e.g., the current element, indices, etc.)
    ///
    /// The loop is generated with [`gen_for_callback`]: an iterator is constructed for this
    /// ndarray, the loop continues while `has_element` holds, `body` is invoked for each
    /// iteration, and `next` advances the iterator afterwards.
    ///
    /// # Errors
    ///
    /// Returns whatever error `gen_for_callback` reports, which includes an error returned
    /// by `body`.
    pub fn foreach<'a, G, F>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, 'a>,
        body: F,
    ) -> Result<(), String>
    where
        G: CodeGenerator + ?Sized,
        F: FnOnce(
            &mut G,
            &mut CodeGenContext<'ctx, 'a>,
            BreakContinueHooks<'ctx>,
            NDIterValue<'ctx>,
        ) -> Result<(), String>,
    {
        gen_for_callback(
            generator,
            ctx,
            Some("ndarray_foreach"),
            // Loop initialisation: build the iterator over this ndarray.
            |generator, ctx| {
                Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self))
            },
            // Loop condition.
            |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)),
            // Loop body supplied by the caller.
            |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
            // Loop update: advance to the next element.
            |generator, ctx, nditer| {
                nditer.next(generator, ctx);
                Ok(())
            },
        )
    }
}

View File

@ -92,7 +92,7 @@ pub unsafe extern "C" fn np_linalg_qr(
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_qr", file!(), line!(), column!(), &err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();

View File

@ -1758,13 +1758,12 @@ def run() -> int32:
test_ndarray_transpose() test_ndarray_transpose()
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_dot()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()
test_ndarray_svd() test_ndarray_svd()
test_ndarray_linalg_inv() test_ndarray_linalg_inv()
test_ndarray_pinv() test_ndarray_pinv()
test_ndarray_matrix_power() # test_ndarray_matrix_power()
test_ndarray_det() test_ndarray_det()
test_ndarray_lu() test_ndarray_lu()
test_ndarray_schur() test_ndarray_schur()