From d6b884664e3394dcc43fb1b9bb64bf3ca2c0703f Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 28 Aug 2024 16:33:03 +0800 Subject: [PATCH] [core] codegen/types: Implement NDArray in terms of i8* Better aligns with the future implementation of ndstrides. --- nac3artiq/src/codegen.rs | 12 +- nac3core/src/codegen/builtin_fns.rs | 107 +++++++++------ nac3core/src/codegen/expr.rs | 66 ++++++--- nac3core/src/codegen/numpy.rs | 181 +++++++++++++++++++------ nac3core/src/codegen/types/ndarray.rs | 49 ++++--- nac3core/src/codegen/values/ndarray.rs | 102 ++++++++++++-- nac3standalone/demo/src/ndarray.py | 3 +- 7 files changed, 372 insertions(+), 148 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 1fcfd4b7..aece9263 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -461,8 +461,7 @@ fn format_rpc_arg<'ctx>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); - let llvm_arg = - NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); + let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None); let llvm_usize_sizeof = ctx .builder @@ -1369,12 +1368,17 @@ fn polymorphic_print<'ctx>( TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); - let val = - NDArrayValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None); + let val = NDArrayValue::from_pointer_value( + value.into_pointer_value(), + llvm_elem_ty, + llvm_usize, + None, + ); let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 77657535..e693faff 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -21,7 +21,10 @@ use super::{ CodeGenContext, CodeGenerator, }; use crate::{ - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys}, + toplevel::{ + helper::{arraylike_flatten_element_type, PrimDef}, + numpy::unpack_ndarray_var_tys, + }, typecheck::typedef::{Type, TypeEnum}, }; @@ -65,10 +68,15 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = - NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); + let arg = NDArrayValue::from_pointer_value( + arg.into_pointer_value(), + ctx.get_llvm_type(generator, elem_ty), + llvm_usize, + None, + ); let ndims = arg.dim_sizes().size(ctx, generator); ctx.make_assert( @@ -143,13 +151,14 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.int32, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), )?; @@ -205,13 +214,14 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.int64, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), )?; @@ -283,13 +293,14 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.uint32, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), )?; @@ -350,13 +361,14 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.uint64, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), )?; @@ -416,13 +428,14 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), )?; @@ -462,13 +475,14 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -502,13 +516,14 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), )?; @@ -567,13 +582,14 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.bool, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| { let elem = call_bool(generator, ctx, (elem_ty, val))?; @@ -621,13 +637,14 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -671,13 +688,14 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -806,8 +824,8 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -906,9 +924,9 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_pointer_value(n, llvm_usize, None); + let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx @@ -926,7 +944,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + let accumulator_addr = generator.gen_var_alloc(ctx, llvm_elem_ty, None)?; let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; unsafe { @@ -1068,8 +1086,8 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1114,6 +1132,7 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); + let llvm_arg_elem_ty = ctx.get_llvm_type(generator, arg_elem_ty); let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1121,7 +1140,7 @@ where ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(x, llvm_usize, None), + NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), |generator, ctx, elem_val| { helper_call_numpy_unary_elementwise( generator, @@ -1508,8 +1527,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1575,8 +1594,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1642,8 +1661,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1709,8 +1728,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1765,8 +1784,8 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1832,8 +1851,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1899,8 +1918,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1960,7 +1979,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2002,7 +2021,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( unimplemented!("{FN_NAME} operates on float type NdArrays only"); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2052,7 +2071,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2107,7 +2126,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2149,7 +2168,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2192,7 +2211,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2245,7 +2264,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); // Changing second parameter to a `NDArray` for uniformity in function call let n2_array = numpy::create_ndarray_const_shape( generator, @@ -2340,7 +2359,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2383,7 +2402,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 93720f9a..01047b30 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1564,10 +1564,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let left_val = - NDArrayValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None); - let right_val = - NDArrayValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None); + let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1); + let llvm_ndarray_dtype2 = ctx.get_llvm_type(generator, ndarray_dtype2); + + let left_val = NDArrayValue::from_pointer_value( + left_val.into_pointer_value(), + llvm_ndarray_dtype1, + llvm_usize, + None, + ); + let right_val = NDArrayValue::from_pointer_value( + right_val.into_pointer_value(), + llvm_ndarray_dtype2, + llvm_usize, + None, + ); let res = if op.base == Operator::MatMult { // MatMult is the only binop which is not an elementwise op @@ -1591,8 +1602,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( BinopVariant::Normal => None, BinopVariant::AugAssign => Some(left_val), }, - (left_val.as_base_value().into(), false), - (right_val.as_base_value().into(), false), + (ty1, left_val.as_base_value().into(), false), + (ty2, right_val.as_base_value().into(), false), |generator, ctx, (lhs, rhs)| { gen_binop_expr_with_values( generator, @@ -1616,8 +1627,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else { let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); + let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_val = NDArrayValue::from_pointer_value( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), + llvm_ndarray_dtype, llvm_usize, None, ); @@ -1629,8 +1642,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( BinopVariant::Normal => None, BinopVariant::AugAssign => Some(ndarray_val), }, - (left_val, !is_ndarray1), - (right_val, !is_ndarray2), + (ty1, left_val, !is_ndarray1), + (ty2, right_val, !is_ndarray2), |generator, ctx, (lhs, rhs)| { gen_binop_expr_with_values( generator, @@ -1810,8 +1823,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); - let val = NDArrayValue::from_pointer_value(val.into_pointer_value(), llvm_usize, None); + let val = NDArrayValue::from_pointer_value( + val.into_pointer_value(), + llvm_ndarray_dtype, + llvm_usize, + None, + ); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1902,15 +1921,21 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let left_val = - NDArrayValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None); + let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1); + + let left_val = NDArrayValue::from_pointer_value( + lhs.into_pointer_value(), + llvm_ndarray_dtype1, + llvm_usize, + None, + ); let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, ctx.primitives.bool, None, - (left_val.as_base_value().into(), false), - (rhs, false), + (left_ty, left_val.as_base_value().into(), false), + (right_ty, rhs, false), |generator, ctx, (lhs, rhs)| { let val = gen_cmpop_expr_with_values( generator, @@ -1941,8 +1966,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx, ctx.primitives.bool, None, - (lhs, !is_ndarray1), - (rhs, !is_ndarray2), + (left_ty, lhs, !is_ndarray1), + (right_ty, rhs, !is_ndarray2), |generator, ctx, (lhs, rhs)| { let val = gen_cmpop_expr_with_values( generator, @@ -2771,8 +2796,12 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( // elements over let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = - NDArrayValue::from_pointer_value(subscripted_ndarray, llvm_usize, None); + let ndarray = NDArrayValue::from_pointer_value( + subscripted_ndarray, + llvm_ndarray_data_t, + llvm_usize, + None, + ); let num_dims = v.load_ndims(ctx); ndarray.store_ndims( @@ -3510,6 +3539,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); + let llvm_ty = ctx.get_llvm_type(generator, *ty); let v = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? @@ -3517,7 +3547,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None); }; - let v = NDArrayValue::from_pointer_value(v, usize, None); + let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 4589ba4d..5db4ac26 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -26,7 +26,7 @@ use super::{ use crate::{ symbol_resolver::ValueEnum, toplevel::{ - helper::PrimDef, + helper::{arraylike_flatten_element_type, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, @@ -42,6 +42,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, ) -> Result, String> { + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -54,7 +55,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - Ok(NDArrayValue::from_pointer_value(ndarray, llvm_usize, None)) + Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None)) } /// Creates an `NDArray` instance from a dynamic shape. @@ -473,8 +474,8 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, res: NDArrayValue<'ctx>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), + lhs: (Type, BasicValueEnum<'ctx>, bool), + rhs: (Type, BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> where @@ -487,8 +488,8 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; + let (lhs_ty, lhs_val, lhs_scalar) = lhs; + let (rhs_ty, rhs_val, rhs_scalar) = rhs; assert!( !(lhs_scalar && rhs_scalar), @@ -499,14 +500,26 @@ where // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { - let lhs_val = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs_val = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); } if !rhs_scalar { - let rhs_val = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs_val = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); } @@ -514,8 +527,14 @@ where let lhs_elem = if lhs_scalar { lhs_val } else { - let lhs = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } @@ -524,8 +543,14 @@ where let rhs_elem = if rhs_scalar { rhs_val } else { - let rhs = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } @@ -671,7 +696,7 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, + (ty, value): (Type, BasicValueEnum<'ctx>), ) -> IntValue<'ctx> { let llvm_usize = generator.get_size_type(ctx.ctx); @@ -679,7 +704,9 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(v) if NDArrayValue::is_representable(v, llvm_usize).is_ok() => { - NDArrayValue::from_pointer_value(v, llvm_usize, None).load_ndims(ctx) + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); + NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) } BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { @@ -694,7 +721,6 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), src_lst: ListValue<'ctx>, dim: u64, @@ -727,6 +753,20 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( |_, _| Ok(llvm_usize.const_int(1, false)), |generator, ctx, _, i| { let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); + let offset = ctx + .builder + .build_int_mul( + offset, + ctx.builder + .build_int_truncate_or_bit_cast( + dst_arr.get_type().element_type().size_of().unwrap(), + offset.get_type(), + "", + ) + .unwrap(), + "", + ) + .unwrap(); let dst_ptr = unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; @@ -741,7 +781,6 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_from_ndlist_impl( generator, ctx, - elem_ty, (dst_arr, dst_ptr), nested_lst_elem, dim + 1, @@ -760,7 +799,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( _ => { let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); + let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); let cpy_len = ctx @@ -816,7 +855,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let object = NDArrayValue::from_pointer_value(object, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); let ndarray = gen_if_else_expr_callback( generator, @@ -878,7 +918,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), (object, object.data().base_ptr(ctx, generator)), 0, @@ -892,6 +931,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( return Ok(NDArrayValue::from_pointer_value( ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), + llvm_elem_ty, llvm_usize, None, )); @@ -1026,7 +1066,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_from_ndlist_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), object, 0, @@ -1099,7 +1138,6 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), dim: u64, @@ -1108,10 +1146,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); + + let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); + // If there are no (remaining) slice expressions, memcpy the entire dimension if slices.is_empty() { - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let stride = call_ndarray_calc_size( generator, ctx, @@ -1162,9 +1202,29 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, _, src_i| { // Calculate the offset of the active slice let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); + let src_data_offset = ctx + .builder + .build_int_mul( + src_data_offset, + ctx.builder + .build_int_cast(sizeof_elem, src_data_offset.get_type(), "") + .unwrap(), + "", + ) + .unwrap(); let dst_i = ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); + let dst_data_offset = ctx + .builder + .build_int_mul( + dst_data_offset, + ctx.builder + .build_int_cast(sizeof_elem, dst_data_offset.get_type(), "") + .unwrap(), + "", + ) + .unwrap(); let (src_ptr, dst_ptr) = unsafe { ( @@ -1176,7 +1236,6 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (dst_arr, dst_ptr), (src_arr, src_ptr), dim + 1, @@ -1293,7 +1352,6 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), (this, this.data().base_ptr(ctx, generator)), 0, @@ -1376,8 +1434,8 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, res: Option>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), + lhs: (Type, BasicValueEnum<'ctx>, bool), + rhs: (Type, BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> where @@ -1390,8 +1448,8 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; + let (lhs_ty, lhs_val, lhs_scalar) = lhs; + let (rhs_ty, rhs_val, rhs_scalar) = rhs; assert!( !(lhs_scalar && rhs_scalar), @@ -1402,10 +1460,22 @@ where let ndarray = res.unwrap_or_else(|| { if lhs_scalar && rhs_scalar { - let lhs_val = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); - let rhs_val = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs_val = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs_val = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); @@ -1421,8 +1491,14 @@ where ) .unwrap() } else { + let dtype = arraylike_flatten_element_type( + &mut ctx.unifier, + if lhs_scalar { rhs_ty } else { lhs_ty }, + ); + let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); let ndarray = NDArrayValue::from_pointer_value( if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), + llvm_elem_ty, llvm_usize, None, ); @@ -1981,11 +2057,18 @@ pub fn gen_ndarray_copy<'ctx>( let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; + let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty); + ndarray_copy_impl( generator, context, this_elem_ty, - NDArrayValue::from_pointer_value(this_arg.into_pointer_value(), llvm_usize, None), + NDArrayValue::from_pointer_value( + this_arg.into_pointer_value(), + llvm_elem_ty, + llvm_usize, + None, + ), ) .map(NDArrayValue::into) } @@ -2004,6 +2087,7 @@ pub fn gen_ndarray_fill<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; + let this_elem_ty = arraylike_flatten_element_type(&mut context.unifier, this_ty); let this_arg = obj .as_ref() .unwrap() @@ -2014,10 +2098,12 @@ pub fn gen_ndarray_fill<'ctx>( let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; + let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty); + ndarray_fill_flattened( generator, context, - NDArrayValue::from_pointer_value(this_arg, llvm_usize, None), + NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), |generator, ctx, _| { let value = if value_arg.is_pointer_value() { let llvm_i1 = ctx.ctx.bool_type(); @@ -2058,7 +2144,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); // Dimensions are reversed in the transposed array @@ -2177,7 +2264,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2454,14 +2542,19 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; let (x1_ty, x1) = x1; - let (_, x2) = x2; + let (x2_ty, x2) = x2; let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); - let n2 = NDArrayValue::from_pointer_value(n2, llvm_usize, None); + let n1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let n2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); + let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); + let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); + + let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); + let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); @@ -2501,7 +2594,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( .build_float_mul(e1, elem2.into_float_value(), "") .unwrap() .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), }; let acc_val = ctx.builder.build_load(acc, "").unwrap(); let acc_val = match acc_val { @@ -2515,7 +2608,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( .build_float_add(e1, product.into_float_value(), "") .unwrap() .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), }; ctx.builder.build_store(acc, acc_val).unwrap(); diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 3f25f828..d6887322 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -67,21 +67,29 @@ impl<'ctx> NDArrayType<'ctx> { } let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(_) = PointerType::try_from(ndarray_data_ty) else { + let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else { return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); }; + let ndarray_data = ndarray_pdata.get_element_type(); + let Ok(ndarray_data) = IntType::try_from(ndarray_data) else { + return Err(format!( + "Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}" + )); + }; + if ndarray_data.get_bit_width() != 8 { + return Err(format!( + "Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + ndarray_data.get_bit_width() + )); + } Ok(()) } /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] - fn llvm_type( - ctx: &'ctx Context, - dtype: BasicTypeEnum<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> PointerType<'ctx> { - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // // * num_dims: Number of dimensions in the array // * dims: Pointer to an array containing the size of each dimension @@ -89,13 +97,13 @@ impl<'ctx> NDArrayType<'ctx> { let field_tys = [ llvm_usize.into(), llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), + ctx.i8_type().ptr_type(AddressSpace::default()).into(), ]; ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`ListType`]. + /// Creates an instance of [`NDArrayType`]. #[must_use] pub fn new( generator: &G, @@ -103,24 +111,21 @@ impl<'ctx> NDArrayType<'ctx> { dtype: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, dtype, llvm_usize); + let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType::from_type(llvm_ndarray, llvm_usize) + NDArrayType { ty: llvm_ndarray, dtype, llvm_usize } } - /// Creates an [`NDArrayType`] from a [`PointerType`]. + /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_type( + ptr_ty: PointerType<'ctx>, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - NDArrayType { - ty: ptr_ty, - dtype: ptr_ty - .get_element_type() - .try_into() - .expect("Expected BasicTypeEnum for dtype of NDArray"), - llvm_usize, - } + NDArrayType { ty: ptr_ty, dtype, llvm_usize } } /// Returns the type of the `size` field of this `ndarray` type. @@ -207,7 +212,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); - NDArrayValue::from_pointer_value(value, self.llvm_usize, name) + NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 908ad2f3..732ed0d3 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -1,7 +1,7 @@ use inkwell::{ - types::{AnyTypeEnum, BasicTypeEnum, IntType}, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue, PointerValue}, - IntPredicate, + AddressSpace, IntPredicate, }; use super::{ @@ -20,6 +20,7 @@ use crate::codegen::{ #[derive(Copy, Clone)] pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -38,12 +39,13 @@ impl<'ctx> NDArrayValue<'ctx> { #[must_use] pub fn from_pointer_value( ptr: PointerValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); - NDArrayValue { value: ptr, llvm_usize, name } + NDArrayValue { value: ptr, dtype, llvm_usize, name } } /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. @@ -138,6 +140,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + let data = ctx + .builder + .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") + .unwrap(); ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); } @@ -149,7 +155,15 @@ impl<'ctx> NDArrayValue<'ctx> { elem_ty: BasicTypeEnum<'ctx>, size: IntValue<'ctx>, ) { - self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); + let itemsize = + ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap(); + let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap(); + + // TODO: What about alignment? + self.store_data( + ctx, + ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(), + ); } /// Returns a proxy object to the field storing the data of this `NDArray`. @@ -164,7 +178,7 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) + NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -282,10 +296,10 @@ pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { fn element_type( &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + _: &CodeGenContext<'ctx, '_>, + _: &G, ) -> AnyTypeEnum<'ctx> { - self.0.data().base_ptr(ctx, generator).get_type().get_element_type() + self.0.dtype.as_any_type_enum() } fn base_ptr( @@ -318,15 +332,34 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - unsafe { + let sizeof_elem = ctx + .builder + .build_int_truncate_or_bit_cast( + self.element_type(ctx, generator).size_of().unwrap(), + idx.get_type(), + "", + ) + .unwrap(); + let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); + let ptr = unsafe { ctx.builder .build_in_bounds_gep( self.base_ptr(ctx, generator), - &[*idx], + &[idx], name.unwrap_or_default(), ) .unwrap() - } + }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } fn ptr_offset( @@ -347,7 +380,17 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx.current_loc, ); - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } } @@ -381,8 +424,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ); let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + let sizeof_elem = ctx + .builder + .build_int_truncate_or_bit_cast( + self.element_type(ctx, generator).size_of().unwrap(), + index.get_type(), + "", + ) + .unwrap(); + let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); - unsafe { + let ptr = unsafe { ctx.builder .build_in_bounds_gep( self.base_ptr(ctx, generator), @@ -390,7 +442,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> name.unwrap_or_default(), ) .unwrap() - } + }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } fn ptr_offset( @@ -455,7 +517,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ) .unwrap(); - unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } } diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 577ad9c3..d42f3b93 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -144,6 +144,7 @@ def test_ndarray_array(): # Copy n2_cpy: ndarray[float, 2] = np_array(n2, copy=False) + output_ndarray_float_2(n2_cpy) n2_cpy.fill(0.0) output_ndarray_float_2(n2_cpy) @@ -1756,7 +1757,7 @@ def run() -> int32: test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_transpose() test_ndarray_reshape() - + test_ndarray_dot() test_ndarray_cholesky() test_ndarray_qr()