From 684aafe54c5228bb6b18839fff3fc41e57819e37 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 15 Nov 2024 15:24:07 +0800 Subject: [PATCH] [core] Add itemsize and strides to NDArray struct --- nac3artiq/src/codegen.rs | 1 + nac3core/src/codegen/builtin_fns.rs | 43 +-- nac3core/src/codegen/expr.rs | 8 +- nac3core/src/codegen/numpy.rs | 82 +++--- nac3core/src/codegen/types/ndarray.rs | 244 ++++++++++++----- nac3core/src/codegen/values/ndarray.rs | 352 ++++++++++++++++++------- 6 files changed, 506 insertions(+), 224 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 77340627..dd58ac88 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1370,6 +1370,7 @@ fn polymorphic_print<'ctx>( let val = NDArrayValue::from_pointer_value( value.into_pointer_value(), llvm_elem_ty, + None, llvm_usize, None, ); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 20d35006..b35a3ec0 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -74,6 +74,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let arg = NDArrayValue::from_pointer_value( arg.into_pointer_value(), ctx.get_llvm_type(generator, elem_ty), + None, llvm_usize, None, ); @@ -153,7 +154,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.int32, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), )?; @@ -216,7 +217,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.int64, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), )?; @@ -295,7 +296,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.uint32, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), )?; @@ -363,7 +364,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.uint64, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), )?; @@ -430,7 +431,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), )?; @@ -477,7 +478,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -518,7 +519,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), )?; @@ -584,7 +585,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.bool, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| { let elem = call_bool(generator, ctx, (elem_ty, val))?; @@ -639,7 +640,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -690,7 +691,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -921,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); + let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx @@ -1135,7 +1136,7 @@ where ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None), |generator, ctx, elem_val| { helper_call_numpy_unary_elementwise( generator, @@ -1974,7 +1975,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2016,7 +2017,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( unimplemented!("{FN_NAME} operates on float type NdArrays only"); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2066,7 +2067,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() @@ -2121,7 +2122,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2163,7 +2164,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() @@ -2206,7 +2207,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() @@ -2259,7 +2260,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); // Changing second parameter to a `NDArray` for uniformity in function call let n2_array = numpy::create_ndarray_const_shape( generator, @@ -2354,7 +2355,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() @@ -2397,7 +2398,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None); let dim0 = unsafe { n1.shape() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 489eaa92..5f33aa66 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1570,12 +1570,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let left_val = NDArrayValue::from_pointer_value( left_val.into_pointer_value(), llvm_ndarray_dtype1, + None, llvm_usize, None, ); let right_val = NDArrayValue::from_pointer_value( right_val.into_pointer_value(), llvm_ndarray_dtype2, + None, llvm_usize, None, ); @@ -1631,6 +1633,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let ndarray_val = NDArrayValue::from_pointer_value( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), llvm_ndarray_dtype, + None, llvm_usize, None, ); @@ -1828,6 +1831,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let val = NDArrayValue::from_pointer_value( val.into_pointer_value(), llvm_ndarray_dtype, + None, llvm_usize, None, ); @@ -1926,6 +1930,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let left_val = NDArrayValue::from_pointer_value( lhs.into_pointer_value(), llvm_ndarray_dtype1, + None, llvm_usize, None, ); @@ -2799,6 +2804,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let ndarray = NDArrayValue::from_pointer_value( subscripted_ndarray, llvm_ndarray_data_t, + None, llvm_usize, None, ); @@ -3542,7 +3548,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None); }; - let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); + let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 58869f2c..92bab809 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -3,6 +3,7 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; +use itertools::Itertools; use nac3parser::ast::{Operator, StrRef}; @@ -27,7 +28,7 @@ use crate::{ symbol_resolver::ValueEnum, toplevel::{ helper::{arraylike_flatten_element_type, PrimDef}, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::unpack_ndarray_var_tys, DefinitionId, }, typecheck::{ @@ -43,19 +44,16 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( elem_ty: Type, ) -> Result, String> { let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray_t = ctx - .get_llvm_type(generator, ndarray_ty) - .into_pointer_type() + let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty) + .as_base_type() .get_element_type() .into_struct_type(); let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None)) + Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None)) } /// Creates an `NDArray` instance from a dynamic shape. @@ -189,28 +187,10 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( // TODO: Disallow dim_sz > u32_MAX } - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - - let num_dims = llvm_usize.const_int(shape.len() as u64, false); - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); - - for (i, &shape_dim) in shape.iter().enumerate() { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let ndarray_dim = unsafe { - ndarray.shape().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, true), - None, - ) - }; - - ctx.builder.build_store(ndarray_dim, shape_dim).unwrap(); - } + let llvm_dtype = ctx.get_llvm_type(generator, elem_ty); + let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype) + .construct_dyn_shape(generator, ctx, shape, None); let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); Ok(ndarray) @@ -338,20 +318,24 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( // Get the length/size of the tuple, which also happens to be the value of `ndims`. let ndims = shape_tuple.get_type().count_fields(); - let mut shape = Vec::with_capacity(ndims as usize); - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .unwrap() - .into_int_value(); + let shape = (0..ndims) + .map(|dim_i| { + ctx.builder + .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) + .map(BasicValueEnum::into_int_value) + .map(|v| { + ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap() + }) + .unwrap() + }) + .collect_vec(); - shape.push(dim); - } create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) } BasicValueEnum::IntValue(shape_int) => { // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + let shape_int = + ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap(); create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) } @@ -505,6 +489,7 @@ where let lhs_val = NDArrayValue::from_pointer_value( lhs_val.into_pointer_value(), llvm_lhs_elem_ty, + None, llvm_usize, None, ); @@ -517,6 +502,7 @@ where let rhs_val = NDArrayValue::from_pointer_value( rhs_val.into_pointer_value(), llvm_rhs_elem_ty, + None, llvm_usize, None, ); @@ -532,6 +518,7 @@ where let lhs = NDArrayValue::from_pointer_value( lhs_val.into_pointer_value(), llvm_lhs_elem_ty, + None, llvm_usize, None, ); @@ -548,6 +535,7 @@ where let rhs = NDArrayValue::from_pointer_value( rhs_val.into_pointer_value(), llvm_rhs_elem_ty, + None, llvm_usize, None, ); @@ -706,7 +694,8 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( { let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); - NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) + NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None) + .load_ndims(ctx) } BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { @@ -856,7 +845,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims if NDArrayValue::is_representable(object, llvm_usize).is_ok() { let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); + let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None); let ndarray = gen_if_else_expr_callback( generator, @@ -932,6 +921,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( return Ok(NDArrayValue::from_pointer_value( ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), llvm_elem_ty, + None, llvm_usize, None, )); @@ -1465,6 +1455,7 @@ where let lhs_val = NDArrayValue::from_pointer_value( lhs_val.into_pointer_value(), llvm_lhs_elem_ty, + None, llvm_usize, None, ); @@ -1473,6 +1464,7 @@ where let rhs_val = NDArrayValue::from_pointer_value( rhs_val.into_pointer_value(), llvm_rhs_elem_ty, + None, llvm_usize, None, ); @@ -1499,6 +1491,7 @@ where let ndarray = NDArrayValue::from_pointer_value( if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), llvm_elem_ty, + None, llvm_usize, None, ); @@ -2061,6 +2054,7 @@ pub fn gen_ndarray_copy<'ctx>( NDArrayValue::from_pointer_value( this_arg.into_pointer_value(), llvm_elem_ty, + None, llvm_usize, None, ), @@ -2098,7 +2092,7 @@ pub fn gen_ndarray_fill<'ctx>( ndarray_fill_flattened( generator, context, - NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), + NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None), |generator, ctx, _| { let value = if value_arg.is_pointer_value() { let llvm_i1 = ctx.ctx.bool_type(); @@ -2140,7 +2134,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); // Dimensions are reversed in the transposed array @@ -2260,7 +2254,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2548,8 +2542,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); - let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None); + let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 46ffb5b8..b4f79f7a 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -1,6 +1,6 @@ use inkwell::context::ContextRef; use inkwell::{ - context::{AsContextRef, Context}, + context::{AsContextRef, Context, ContextRef}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, values::{IntValue, PointerValue}, AddressSpace, @@ -11,9 +11,13 @@ use super::{ structure::{FieldIndexCounter, StructField, StructFields}, ProxyType, }; -use crate::codegen::{ - values::{ArraySliceValue, NDArrayValue, ProxyValue}, - {CodeGenContext, CodeGenerator}, +use crate::{ + codegen::{ + values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator}, + {CodeGenContext, CodeGenerator}, + }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, + typecheck::typedef::Type, }; /// Proxy type for a `ndarray` type in LLVM. @@ -21,13 +25,17 @@ use crate::codegen::{ pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, + // TODO(Derppening): Make this non-optional + ndims: Option, llvm_usize: IntType<'ctx>, } #[derive(PartialEq, Eq, Clone, Copy)] pub struct NDArrayStructFields<'ctx> { + pub itemsize: StructField<'ctx, IntValue<'ctx>>, pub ndims: StructField<'ctx, IntValue<'ctx>>, pub shape: StructField<'ctx, PointerValue<'ctx>>, + pub strides: StructField<'ctx, PointerValue<'ctx>>, pub data: StructField<'ctx, PointerValue<'ctx>>, } @@ -37,12 +45,18 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { let mut counter = FieldIndexCounter::default(); NDArrayStructFields { + itemsize: StructField::create(&mut counter, "itemsize", llvm_usize), ndims: StructField::create(&mut counter, "ndims", llvm_usize), shape: StructField::create( &mut counter, "shape", llvm_usize.ptr_type(AddressSpace::default()), ), + strides: StructField::create( + &mut counter, + "strides", + llvm_usize.ptr_type(AddressSpace::default()), + ), data: StructField::create( &mut counter, "data", @@ -52,7 +66,13 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { } fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { - vec![self.ndims.into(), self.shape.into(), self.data.into()] + vec![ + self.itemsize.into(), + self.ndims.into(), + self.shape.into(), + self.strides.into(), + self.data.into(), + ] } } @@ -62,70 +82,45 @@ impl<'ctx> NDArrayType<'ctx> { llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec(); + let llvm_ndarray_ty = llvm_ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); }; - if llvm_ndarray_ty.count_fields() != 3 { + if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() { return Err(format!( - "Expected 3 fields in `NDArray`, got {}", + "Expected {} fields in `NDArray`, got {}", + llvm_expected_ty.len(), llvm_ndarray_ty.count_fields() )); } - let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); - let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { - return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); - }; - if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `ndarray.0`, got {}-bit int", - llvm_usize.get_bit_width(), - ndarray_ndims_ty.get_bit_width() - )); - } - - let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); - let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { - return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); - }; - let ndarray_dims = ndarray_pdims.get_element_type(); - let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { - return Err(format!( - "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" - )); - }; - if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - llvm_usize.get_bit_width(), - ndarray_dims.get_bit_width() - )); - } - - let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else { - return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); - }; - let ndarray_data = ndarray_pdata.get_element_type(); - let Ok(ndarray_data) = IntType::try_from(ndarray_data) else { - return Err(format!( - "Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}" - )); - }; - if ndarray_data.get_bit_width() != 8 { - return Err(format!( - "Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - ndarray_data.get_bit_width() - )); - } + llvm_expected_ty + .iter() + .enumerate() + .map(|(i, expected_ty)| { + (expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap()) + }) + .try_for_each(|(expected_ty, actual_ty)| { + if expected_ty == actual_ty { + Ok(()) + } else { + Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}")) + } + })?; Ok(()) } // TODO: Move this into e.g. StructProxyType #[must_use] - fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> NDArrayStructFields<'ctx> { NDArrayStructFields::new(ctx, llvm_usize) } @@ -133,7 +128,7 @@ impl<'ctx> NDArrayType<'ctx> { #[must_use] pub fn get_fields( &self, - ctx: &'ctx Context, + ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>, ) -> NDArrayStructFields<'ctx> { Self::fields(ctx, llvm_usize) @@ -142,7 +137,7 @@ impl<'ctx> NDArrayType<'ctx> { /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { - // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } + // struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* } // // * data : Pointer to an array containing the array data // * itemsize: The size of each NDArray elements in bytes @@ -165,7 +160,28 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType { ty: llvm_ndarray, dtype, llvm_usize } + NDArrayType { ty: llvm_ndarray, dtype, ndims: None, llvm_usize } + } + + /// Creates an [`NDArrayType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let llvm_dtype = ctx.get_llvm_type(generator, dtype); + let llvm_usize = generator.get_size_type(ctx.ctx); + + NDArrayType { + ty: Self::llvm_type(ctx.ctx, llvm_usize), + dtype: llvm_dtype, + ndims: Some(ndims), + llvm_usize, + } } /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. @@ -177,7 +193,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - NDArrayType { ty: ptr_ty, dtype, llvm_usize } + NDArrayType { ty: ptr_ty, dtype, ndims: None, llvm_usize } } /// Returns the type of the `size` field of this `ndarray` type. @@ -186,7 +202,7 @@ impl<'ctx> NDArrayType<'ctx> { self.as_base_type() .get_element_type() .into_struct_type() - .get_field_type_at_index(0) + .get_field_type_at_index(1) .map(BasicTypeEnum::into_int_type) .unwrap() } @@ -196,6 +212,114 @@ impl<'ctx> NDArrayType<'ctx> { pub fn element_type(&self) -> BasicTypeEnum<'ctx> { self.dtype } + + /// Returns the number of dimensions represented by this [`NDArrayType`], or [`None`] if it is + /// not known. + #[must_use] + pub fn ndims_as_value(&self) -> Option> { + self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) + } + + /// Allocate an ndarray on the stack given its `ndims` and `dtype`. + /// + /// `shape` and `strides` will be automatically allocated onto the stack. + /// + /// The returned ndarray's content will be: + /// - `data`: uninitialized. + /// - `itemsize`: set to the `sizeof()` of `dtype`. + /// - `ndims`: set to the value of `ndims`. + /// - `shape`: allocated with an array of length `ndims` with uninitialized values. + /// - `strides`: allocated with an array of length `ndims` with uninitialized values. + #[must_use] + pub fn construct_uninitialized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndims: Option, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.new_value(generator, ctx, name); + + let itemsize = ctx + .builder + .build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") + .unwrap(); + ndarray.store_itemsize(ctx, generator, itemsize); + + let ndims_val = self.llvm_usize.const_int(ndims.or(self.ndims).unwrap(), false); + ndarray.store_ndims(ctx, generator, ndims_val); + + ndarray.create_shape(ctx, self.llvm_usize, ndims_val); + ndarray.create_strides(ctx, self.llvm_usize, ndims_val); + + ndarray + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + #[must_use] + pub fn construct_const_shape( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &[u64], + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name); + + // Write shape + let ndarray_shape = ndarray.shape(); + for (i, dim) in shape.iter().enumerate() { + let dim = self.llvm_usize.const_int(*dim, false); + unsafe { + ndarray_shape.set_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i as u64, false), + dim, + ); + } + } + + ndarray + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + #[must_use] + pub fn construct_dyn_shape( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &[IntValue<'ctx>], + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name); + + // Write shape + let ndarray_shape = ndarray.shape(); + for (i, dim) in shape.iter().enumerate() { + assert_eq!( + dim.get_type(), + self.llvm_usize, + "Expected {} but got {}", + self.llvm_usize.print_to_string(), + dim.get_type().print_to_string() + ); + unsafe { + ndarray_shape.set_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i as u64, false), + *dim, + ); + } + } + + ndarray + } } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { @@ -264,7 +388,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); - NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name) + NDArrayValue::from_pointer_value(value, self.dtype, self.ndims, self.llvm_usize, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index d5bedc4c..e4bce3e9 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -21,6 +21,7 @@ use crate::codegen::{ pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, + ndims: Option, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -40,12 +41,13 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn from_pointer_value( ptr: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, + ndims: Option, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); - NDArrayValue { value: ptr, dtype, llvm_usize, name } + NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. @@ -75,6 +77,33 @@ impl<'ctx> NDArrayValue<'ctx> { ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() } + /// Returns the pointer to the field storing the size of each element of this `NDArray`. + fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .itemsize + .ptr_by_gep(ctx, self.value, self.name) + } + + /// Stores the size of each element `itemsize` into this instance. + pub fn store_itemsize( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_store(pndims, ndims).unwrap(); + } + + /// Returns the size of each element of this `NDArray` as a value. + pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + } + /// Returns the double-indirection pointer to the `shape` array, as if by calling /// `getelementptr` on the field. fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -105,6 +134,36 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayShapeProxy(self) } + /// Returns the double-indirection pointer to the `stride` array, as if by calling + /// `getelementptr` on the field. + fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .strides + .ptr_by_gep(ctx, self.value, self.name) + } + + /// Stores the array of dimension sizes `dims` into this instance. + fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); + } + + /// Convenience method for creating a new array storing the stride with the given `size`. + pub fn create_strides( + &self, + ctx: &CodeGenContext<'ctx, '_>, + llvm_usize: IntType<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); + } + + /// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`. + #[must_use] + pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> { + NDArrayStridesProxy(self) + } + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -168,103 +227,6 @@ impl<'ctx> From> for PointerValue<'ctx> { } } -/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.shape().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.0.load_ndims(ctx) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} - -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { - fn downcast_to_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> IntValue<'ctx> { - value.into_int_value() - } -} - -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { - fn upcast_from_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: IntValue<'ctx>, - ) -> BasicValueEnum<'ctx> { - value.into() - } -} - /// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. #[derive(Copy, Clone)] pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); @@ -515,3 +477,197 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, for NDArrayDataProxy<'ctx, '_> { } + +/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); + +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.shape().base_ptr(ctx, generator).get_type().get_element_type() + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder + .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> IntValue<'ctx> { + self.0.load_ndims(ctx) + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let size = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(*idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} + +impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { + fn downcast_to_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> IntValue<'ctx> { + value.into_int_value() + } +} + +impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { + fn upcast_from_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + value.into() + } +} + +/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); + +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.shape().base_ptr(ctx, generator).get_type().get_element_type() + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder + .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> IntValue<'ctx> { + self.0.load_ndims(ctx) + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let size = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(*idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} + +impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { + fn downcast_to_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> IntValue<'ctx> { + value.into_int_value() + } +} + +impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { + fn upcast_from_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + value.into() + } +}