From a770d9f41598c8ef2cc5a4135664a5d0ff7255b9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Nov 2024 15:53:29 +0800 Subject: [PATCH] [core] coregen/types: Implement StructFields for NDArray Also rename some fields to better align with their naming in numpy. --- nac3artiq/src/codegen.rs | 10 +-- nac3core/src/codegen/builtin_fns.rs | 43 ++++++------- nac3core/src/codegen/expr.rs | 12 ++-- nac3core/src/codegen/irrt/ndarray.rs | 14 ++--- nac3core/src/codegen/numpy.rs | 73 ++++++++++------------ nac3core/src/codegen/types/ndarray.rs | 74 ++++++++++++++++++---- nac3core/src/codegen/values/ndarray.rs | 86 ++++++++++---------------- 7 files changed, 165 insertions(+), 147 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 8f7d2a8f..77340627 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -498,7 +498,7 @@ fn format_rpc_arg<'ctx>( call_memcpy_generic( ctx, pbuffer_dims_begin, - llvm_arg.dim_sizes().base_ptr(ctx, generator), + llvm_arg.shape().base_ptr(ctx, generator), dims_buf_sz, llvm_i1.const_zero(), ); @@ -612,7 +612,7 @@ fn format_rpc_ret<'ctx>( // Set `ndarray.ndims` ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); // Allocate `ndarray.shape` [size_t; ndims] - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); + ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx)); /* ndarray now: @@ -702,7 +702,7 @@ fn format_rpc_ret<'ctx>( call_memcpy_generic( ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), + ndarray.shape().base_ptr(ctx, generator), pbuffer_dims, sizeof_dims, llvm_i1.const_zero(), @@ -714,7 +714,7 @@ fn format_rpc_ret<'ctx>( // `ndarray.shape` must be initialized beforehand in this implementation // (for ndarray.create_data() to know how many elements to allocate) let num_elements = - call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None)); // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { @@ -1373,7 +1373,7 @@ fn polymorphic_print<'ctx>( llvm_usize, None, ); - let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); + let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index e693faff..20d35006 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -78,7 +78,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( None, ); - let ndims = arg.dim_sizes().size(ctx, generator); + let ndims = arg.shape().size(ctx, generator); ctx.make_assert( generator, ctx.builder @@ -91,12 +91,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ); let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) + arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() @@ -927,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx .builder @@ -1981,12 +1976,12 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2023,12 +2018,12 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2074,12 +2069,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2128,12 +2123,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2171,12 +2166,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2214,12 +2209,12 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2284,12 +2279,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let n2_array = n2_array.as_base_value().as_basic_value_enum(); let outdim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let outdim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2362,7 +2357,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; @@ -2405,7 +2400,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f4ab5fba..489eaa92 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2631,7 +2631,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let llvm_i32 = ctx.ctx.i32_type(); let len = unsafe { - v.dim_sizes().get_typed_unchecked( + v.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(dim, true), @@ -2672,7 +2672,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ExprKind::Slice { lower, upper, step } => { let dim_sz = unsafe { - v.dim_sizes().get_typed_unchecked( + v.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(dim, false), @@ -2813,7 +2813,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); let ndarray_num_dims = ctx .builder @@ -2824,7 +2824,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ) .unwrap(); let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( + v.shape().ptr_offset_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -2833,7 +2833,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( }; call_memcpy_generic( ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), + ndarray.shape().base_ptr(ctx, generator), v_dims_src_ptr, ctx.builder .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") @@ -2845,7 +2845,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); let ndarray_num_elems = ctx diff --git a/nac3core/src/codegen/irrt/ndarray.rs b/nac3core/src/codegen/irrt/ndarray.rs index bfec1d56..541c1175 100644 --- a/nac3core/src/codegen/irrt/ndarray.rs +++ b/nac3core/src/codegen/irrt/ndarray.rs @@ -103,7 +103,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( }); let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); + let ndarray_dims = ndarray.shape(); let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); @@ -172,7 +172,7 @@ where }); let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); + let ndarray_dims = ndarray.shape(); let index = ctx .builder @@ -259,8 +259,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let (lhs_dim_sz, rhs_dim_sz) = unsafe { ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), + rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), ) }; @@ -298,9 +298,9 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( .unwrap(); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_dims = lhs.shape().base_ptr(ctx, generator); let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_dims = rhs.shape().base_ptr(ctx, generator); let rhs_ndims = rhs.load_ndims(ctx); let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); @@ -362,7 +362,7 @@ pub fn call_ndarray_calc_broadcast_index< let broadcast_size = broadcast_idx.size(ctx, generator); let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_dims = array.shape().base_ptr(ctx, generator); let array_ndims = array.load_ndims(ctx); let broadcast_idx_ptr = unsafe { broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 5db4ac26..58869f2c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -128,7 +128,7 @@ where ndarray.store_ndims(ctx, generator, num_dims); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); // Copy the dimension sizes from shape to ndarray.dims let shape_len = shape_len_fn(generator, ctx, shape)?; @@ -144,7 +144,7 @@ where let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let ndarray_pdim = - unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; + unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); @@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ndarray.store_ndims(ctx, generator, num_dims); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); for (i, &shape_dim) in shape.iter().enumerate() { let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let ndarray_dim = unsafe { - ndarray.dim_sizes().ptr_offset_unchecked( + ndarray.shape().ptr_offset_unchecked( ctx, generator, &llvm_usize.const_int(i as u64, true), @@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); @@ -380,7 +380,7 @@ where let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); @@ -739,7 +739,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let stride = call_ndarray_calc_size( generator, ctx, - &dst_arr.dim_sizes(), + &dst_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); @@ -1155,7 +1155,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let stride = call_ndarray_calc_size( generator, ctx, - &src_arr.dim_sizes(), + &src_arr.shape(), (Some(llvm_usize.const_int(dim, false)), None), ); let stride = @@ -1173,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let src_stride = call_ndarray_calc_size( generator, ctx, - &src_arr.dim_sizes(), + &src_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); let dst_stride = call_ndarray_calc_size( generator, ctx, - &dst_arr.dim_sizes(), + &dst_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); @@ -1278,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( &this, |_, ctx, shape| Ok(shape.load_ndims(ctx)), |generator, ctx, shape, idx| unsafe { - Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, )? } else { @@ -1286,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); let ndims = this.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndims); + ndarray.create_shape(ctx, llvm_usize, ndims); // Populate the first slices.len() dimensions by computing the size of each dim slice for (i, (start, stop, step)) in slices.iter().enumerate() { @@ -1318,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); unsafe { - ndarray.dim_sizes().set_typed_unchecked( + ndarray.shape().set_typed_unchecked( ctx, generator, &llvm_usize.const_int(i as u64, false), @@ -1336,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( (this.load_ndims(ctx), false), |generator, ctx, _, idx| { unsafe { - let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); + let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None); + ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz); } Ok(()) @@ -1397,7 +1397,7 @@ where &operand, |_, ctx, v| Ok(v.load_ndims(ctx)), |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, ) .unwrap() @@ -1510,7 +1510,7 @@ where &ndarray, |_, ctx, v| Ok(v.load_ndims(ctx)), |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, ) .unwrap() @@ -1571,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( if let Some(res) = res { let res_ndims = res.load_ndims(ctx); let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( + res.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1582,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ) }; let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( + rhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1634,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) + lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; // lhs.dims[1] == rhs.dims[0] @@ -1681,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( }, |generator, ctx| { Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( + lhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_zero(), @@ -1691,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( }, |generator, ctx| { Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( + rhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1718,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( let common_dim = { let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( + lhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1726,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ) }; let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); @@ -2146,7 +2141,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); // Dimensions are reversed in the transposed array let out = create_ndarray_dyn_shape( @@ -2161,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .builder .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") .unwrap(); - unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } + unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) } }, ) .unwrap(); @@ -2198,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") .unwrap(); let dim = unsafe { - n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) + n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None) }; let rem_idx_val = @@ -2266,7 +2261,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2494,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ); // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); + let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), @@ -2556,8 +2551,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); ctx.make_assert( generator, diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index d6887322..46ffb5b8 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -1,11 +1,16 @@ +use inkwell::context::ContextRef; use inkwell::{ - context::Context, + context::{AsContextRef, Context}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; +use itertools::Itertools; -use super::ProxyType; +use super::{ + structure::{FieldIndexCounter, StructField, StructFields}, + ProxyType, +}; use crate::codegen::{ values::{ArraySliceValue, NDArrayValue, ProxyValue}, {CodeGenContext, CodeGenerator}, @@ -19,6 +24,38 @@ pub struct NDArrayType<'ctx> { llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct NDArrayStructFields<'ctx> { + pub ndims: StructField<'ctx, IntValue<'ctx>>, + pub shape: StructField<'ctx, PointerValue<'ctx>>, + pub data: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { + fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) }; + let mut counter = FieldIndexCounter::default(); + + NDArrayStructFields { + ndims: StructField::create(&mut counter, "ndims", llvm_usize), + shape: StructField::create( + &mut counter, + "shape", + llvm_usize.ptr_type(AddressSpace::default()), + ), + data: StructField::create( + &mut counter, + "data", + ctx.i8_type().ptr_type(AddressSpace::default()), + ), + } + } + + fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { + vec![self.ndims.into(), self.shape.into(), self.data.into()] + } +} + impl<'ctx> NDArrayType<'ctx> { /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. pub fn is_representable( @@ -86,19 +123,34 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) } + // TODO: Move this into e.g. StructProxyType + #[must_use] + fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { + NDArrayStructFields::new(ctx, llvm_usize) + } + + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields( + &self, + ctx: &'ctx Context, + llvm_usize: IntType<'ctx>, + ) -> NDArrayStructFields<'ctx> { + Self::fields(ctx, llvm_usize) + } + /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let field_tys = [ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - ctx.i8_type().ptr_type(AddressSpace::default()).into(), - ]; + // * data : Pointer to an array containing the array data + // * itemsize: The size of each NDArray elements in bytes + // * ndims : Number of dimensions in the array + // * shape : Pointer to an array containing the shape of the NDArray + // * strides : Pointer to an array indicating the number of bytes between each element at a dimension + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 732ed0d3..d5bedc4c 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -50,18 +50,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .ndims + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the number of dimensions `ndims` into this instance. @@ -83,59 +75,43 @@ impl<'ctx> NDArrayValue<'ctx> { ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() } - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + /// Returns the double-indirection pointer to the `shape` array, as if by calling + /// `getelementptr` on the field. + fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .shape + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); + fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( + pub fn create_shape( &self, ctx: &CodeGenContext<'ctx, '_>, llvm_usize: IntType<'ctx>, size: IntValue<'ctx>, ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); + self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); } /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) + pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> { + NDArrayShapeProxy(self) } /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ) - .unwrap() - } + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .data + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. @@ -194,15 +170,15 @@ impl<'ctx> From> for PointerValue<'ctx> { /// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. #[derive(Copy, Clone)] -pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); +pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> AnyTypeEnum<'ctx> { - self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() + self.0.shape().base_ptr(ctx, generator).get_type().get_element_type() } fn base_ptr( @@ -213,7 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); ctx.builder - .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) + .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) .map(BasicValueEnum::into_pointer_value) .unwrap() } @@ -227,7 +203,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { } } -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -266,10 +242,10 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> } } -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { fn downcast_to_type( &self, _: &mut CodeGenContext<'ctx, '_>, @@ -279,7 +255,7 @@ impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ct } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { fn upcast_from_type( &self, _: &mut CodeGenContext<'ctx, '_>, @@ -491,7 +467,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> let (dim_idx, dim_sz) = unsafe { ( indices.get_unchecked(ctx, generator, &i, None).into_int_value(), - self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), + self.0.shape().get_typed_unchecked(ctx, generator, &i, None), ) }; let dim_idx = ctx