[core] coregen/types: Implement StructFields for NDArray

Also rename some fields to better align with their naming in numpy.
This commit is contained in:
David Mak 2024-11-13 15:53:29 +08:00
parent b3b61959dd
commit 41fb27c913
7 changed files with 165 additions and 147 deletions

View File

@ -498,7 +498,7 @@ fn format_rpc_arg<'ctx>(
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
pbuffer_dims_begin, pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator), llvm_arg.shape().base_ptr(ctx, generator),
dims_buf_sz, dims_buf_sz,
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
@ -612,7 +612,7 @@ fn format_rpc_ret<'ctx>(
// Set `ndarray.ndims` // Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims] // Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
/* /*
ndarray now: ndarray now:
@ -702,7 +702,7 @@ fn format_rpc_ret<'ctx>(
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ndarray.dim_sizes().base_ptr(ctx, generator), ndarray.shape().base_ptr(ctx, generator),
pbuffer_dims, pbuffer_dims,
sizeof_dims, sizeof_dims,
llvm_i1.const_zero(), llvm_i1.const_zero(),
@ -714,7 +714,7 @@ fn format_rpc_ret<'ctx>(
// `ndarray.shape` must be initialized beforehand in this implementation // `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate) // (for ndarray.create_data() to know how many elements to allocate)
let num_elements = let num_elements =
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
@ -1373,7 +1373,7 @@ fn polymorphic_print<'ctx>(
llvm_usize, llvm_usize,
None, None,
); );
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last = let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();

View File

@ -78,7 +78,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let ndims = arg.dim_sizes().size(ctx, generator); let ndims = arg.shape().size(ctx, generator);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder ctx.builder
@ -91,12 +91,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
); );
let len = unsafe { let len = unsafe {
arg.dim_sizes().get_typed_unchecked( arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}; };
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
@ -927,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx let n_sz_eqz = ctx
.builder .builder
@ -1981,12 +1976,12 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2023,12 +2018,12 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2074,12 +2069,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2128,12 +2123,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2171,12 +2166,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2214,12 +2209,12 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe { let dim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2284,12 +2279,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
let n2_array = n2_array.as_base_value().as_basic_value_enum(); let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe { let outdim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let outdim1 = unsafe { let outdim1 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value() .into_int_value()
}; };
@ -2362,7 +2357,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
@ -2405,7 +2400,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.dim_sizes() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };

View File

@ -2631,7 +2631,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe { let len = unsafe {
v.dim_sizes().get_typed_unchecked( v.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(dim, true), &llvm_usize.const_int(dim, true),
@ -2672,7 +2672,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ExprKind::Slice { lower, upper, step } => { ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe { let dim_sz = unsafe {
v.dim_sizes().get_typed_unchecked( v.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(dim, false), &llvm_usize.const_int(dim, false),
@ -2813,7 +2813,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
); );
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx let ndarray_num_dims = ctx
.builder .builder
@ -2824,7 +2824,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
) )
.unwrap(); .unwrap();
let v_dims_src_ptr = unsafe { let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked( v.shape().ptr_offset_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -2833,7 +2833,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
}; };
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ndarray.dim_sizes().base_ptr(ctx, generator), ndarray.shape().base_ptr(ctx, generator),
v_dims_src_ptr, v_dims_src_ptr,
ctx.builder ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
@ -2845,7 +2845,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
let ndarray_num_elems = ctx let ndarray_num_elems = ctx

View File

@ -103,7 +103,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
}); });
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
@ -172,7 +172,7 @@ where
}); });
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); let ndarray_dims = ndarray.shape();
let index = ctx let index = ctx
.builder .builder
@ -259,8 +259,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe { let (lhs_dim_sz, rhs_dim_sz) = unsafe {
( (
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
) )
}; };
@ -298,9 +298,9 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
.unwrap(); .unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx); let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx); let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
@ -362,7 +362,7 @@ pub fn call_ndarray_calc_broadcast_index<
let broadcast_size = broadcast_idx.size(ctx, generator); let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator); let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx); let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe { let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)

View File

@ -128,7 +128,7 @@ where
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims // Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -144,7 +144,7 @@ where
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_pdim = let ndarray_pdim =
unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() { for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe { let ndarray_dim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked( ndarray.shape().ptr_offset_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(i as u64, true), &llvm_usize.const_int(i as u64, true),
@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
@ -380,7 +380,7 @@ where
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
@ -739,7 +739,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&dst_arr.dim_sizes(), &dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
@ -1155,7 +1155,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.shape(),
(Some(llvm_usize.const_int(dim, false)), None), (Some(llvm_usize.const_int(dim, false)), None),
); );
let stride = let stride =
@ -1173,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let src_stride = call_ndarray_calc_size( let src_stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
let dst_stride = call_ndarray_calc_size( let dst_stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&dst_arr.dim_sizes(), &dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
@ -1278,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
&this, &this,
|_, ctx, shape| Ok(shape.load_ndims(ctx)), |_, ctx, shape| Ok(shape.load_ndims(ctx)),
|generator, ctx, shape, idx| unsafe { |generator, ctx, shape, idx| unsafe {
Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
)? )?
} else { } else {
@ -1286,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
let ndims = this.load_ndims(ctx); let ndims = this.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndims); ndarray.create_shape(ctx, llvm_usize, ndims);
// Populate the first slices.len() dimensions by computing the size of each dim slice // Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() { for (i, (start, stop, step)) in slices.iter().enumerate() {
@ -1318,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
unsafe { unsafe {
ndarray.dim_sizes().set_typed_unchecked( ndarray.shape().set_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(i as u64, false), &llvm_usize.const_int(i as u64, false),
@ -1336,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
unsafe { unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz);
} }
Ok(()) Ok(())
@ -1397,7 +1397,7 @@ where
&operand, &operand,
|_, ctx, v| Ok(v.load_ndims(ctx)), |_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe { |generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
) )
.unwrap() .unwrap()
@ -1510,7 +1510,7 @@ where
&ndarray, &ndarray,
|_, ctx, v| Ok(v.load_ndims(ctx)), |_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe { |generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
) )
.unwrap() .unwrap()
@ -1571,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if let Some(res) = res { if let Some(res) = res {
let res_ndims = res.load_ndims(ctx); let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe { let res_dim0 = unsafe {
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let res_dim1 = unsafe { let res_dim1 = unsafe {
res.dim_sizes().get_typed_unchecked( res.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1582,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
) )
}; };
let lhs_dim0 = unsafe { let lhs_dim0 = unsafe {
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let rhs_dim1 = unsafe { let rhs_dim1 = unsafe {
rhs.dim_sizes().get_typed_unchecked( rhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1634,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe { let lhs_dim1 = unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}; };
let rhs_dim0 = unsafe { let rhs_dim0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
// lhs.dims[1] == rhs.dims[0] // lhs.dims[1] == rhs.dims[0]
@ -1681,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
}, },
|generator, ctx| { |generator, ctx| {
Ok(Some(unsafe { Ok(Some(unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_zero(), &llvm_usize.const_zero(),
@ -1691,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
}, },
|generator, ctx| { |generator, ctx| {
Ok(Some(unsafe { Ok(Some(unsafe {
rhs.dim_sizes().get_typed_unchecked( rhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1718,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
let common_dim = { let common_dim = {
let lhs_idx1 = unsafe { let lhs_idx1 = unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1726,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
) )
}; };
let rhs_idx0 = unsafe { let rhs_idx0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
@ -2146,7 +2141,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape( let out = create_ndarray_dyn_shape(
@ -2161,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.builder .builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap(); .unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
}, },
) )
.unwrap(); .unwrap();
@ -2198,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap(); .unwrap();
let dim = unsafe { let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
}; };
let rem_idx_val = let rem_idx_val =
@ -2266,7 +2261,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2494,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
); );
// The new shape must be compatible with the old shape // The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
@ -2556,8 +2551,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
ctx.make_assert( ctx.make_assert(
generator, generator,

View File

@ -1,11 +1,16 @@
use inkwell::context::ContextRef;
use inkwell::{ use inkwell::{
context::Context, context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
}; };
use itertools::Itertools;
use super::ProxyType; use super::{
structure::{FieldIndexCounter, StructField, StructFields},
ProxyType,
};
use crate::codegen::{ use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue}, values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator}, {CodeGenContext, CodeGenerator},
@ -19,6 +24,38 @@ pub struct NDArrayType<'ctx> {
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
} }
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> {
pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = FieldIndexCounter::default();
NDArrayStructFields {
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
}
}
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![self.ndims.into(), self.shape.into(), self.data.into()]
}
}
impl<'ctx> NDArrayType<'ctx> { impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable( pub fn is_representable(
@ -86,19 +123,34 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use] #[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
// //
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data : Pointer to an array containing the array data // * data : Pointer to an array containing the array data
let field_tys = [ // * itemsize: The size of each NDArray elements in bytes
llvm_usize.into(), // * ndims : Number of dimensions in the array
llvm_usize.ptr_type(AddressSpace::default()).into(), // * shape : Pointer to an array containing the shape of the NDArray
ctx.i8_type().ptr_type(AddressSpace::default()).into(), // * strides : Pointer to an array indicating the number of bytes between each element at a dimension
]; let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }

View File

@ -50,18 +50,10 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); self.get_type()
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); .get_fields(ctx.ctx, self.llvm_usize)
.ndims
unsafe { .ptr_by_gep(ctx, self.value, self.name)
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
} }
/// Stores the number of dimensions `ndims` into this instance. /// Stores the number of dimensions `ndims` into this instance.
@ -83,59 +75,43 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
} }
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `shape` array, as if by calling
/// on the field. /// `getelementptr` on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); self.get_type()
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); .get_fields(ctx.ctx, self.llvm_usize)
.shape
unsafe { .ptr_by_gep(ctx, self.value, self.name)
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
} }
/// Stores the array of dimension sizes `dims` into this instance. /// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
} }
/// Convenience method for creating a new array storing dimension sizes with the given `size`. /// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes( pub fn create_shape(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>, size: IntValue<'ctx>,
) { ) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
} }
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use] #[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayDimsProxy(self) NDArrayShapeProxy(self)
} }
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); self.get_type()
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); .get_fields(ctx.ctx, self.llvm_usize)
.data
unsafe { .ptr_by_gep(ctx, self.value, self.name)
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
)
.unwrap()
}
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
@ -194,15 +170,15 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. /// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>( fn element_type<G: CodeGenerator + ?Sized>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
generator: &G, generator: &G,
) -> AnyTypeEnum<'ctx> { ) -> AnyTypeEnum<'ctx> {
self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
} }
fn base_ptr<G: CodeGenerator + ?Sized>( fn base_ptr<G: CodeGenerator + ?Sized>(
@ -213,7 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder ctx.builder
.build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) .build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap() .unwrap()
} }
@ -227,7 +203,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
} }
} }
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>( unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -266,10 +242,10 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
} }
} }
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type( fn downcast_to_type(
&self, &self,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
@ -279,7 +255,7 @@ impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ct
} }
} }
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type( fn upcast_from_type(
&self, &self,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
@ -491,7 +467,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
let (dim_idx, dim_sz) = unsafe { let (dim_idx, dim_sz) = unsafe {
( (
indices.get_unchecked(ctx, generator, &i, None).into_int_value(), indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
) )
}; };
let dim_idx = ctx let dim_idx = ctx