[core] Add itemsize and strides to NDArray struct

This commit is contained in:
David Mak 2024-11-15 15:24:07 +08:00
parent 144f0922db
commit 3808690b78
6 changed files with 482 additions and 221 deletions

View File

@ -1376,6 +1376,7 @@ fn polymorphic_print<'ctx>(
let val = NDArrayValue::from_pointer_value( let val = NDArrayValue::from_pointer_value(
value.into_pointer_value(), value.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );

View File

@ -74,6 +74,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
let arg = NDArrayValue::from_pointer_value( let arg = NDArrayValue::from_pointer_value(
arg.into_pointer_value(), arg.into_pointer_value(),
ctx.get_llvm_type(generator, elem_ty), ctx.get_llvm_type(generator, elem_ty),
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -153,7 +154,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int32, ctx.primitives.int32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -216,7 +217,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int64, ctx.primitives.int64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -295,7 +296,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint32, ctx.primitives.uint32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -363,7 +364,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint64, ctx.primitives.uint64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -430,7 +431,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
)?; )?;
@ -477,7 +478,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -518,7 +519,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
)?; )?;
@ -584,7 +585,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.bool, ctx.primitives.bool,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| { |generator, ctx, val| {
let elem = call_bool(generator, ctx, (elem_ty, val))?; let elem = call_bool(generator, ctx, (elem_ty, val))?;
@ -639,7 +640,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -690,7 +691,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -921,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx let n_sz_eqz = ctx
@ -1135,7 +1136,7 @@ where
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None),
|generator, ctx, elem_val| { |generator, ctx, elem_val| {
helper_call_numpy_unary_elementwise( helper_call_numpy_unary_elementwise(
generator, generator,
@ -1974,7 +1975,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2016,7 +2017,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unimplemented!("{FN_NAME} operates on float type NdArrays only");
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2066,7 +2067,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2121,7 +2122,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2163,7 +2164,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2206,7 +2207,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2259,7 +2260,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call // Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape( let n2_array = numpy::create_ndarray_const_shape(
generator, generator,
@ -2354,7 +2355,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2397,7 +2398,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()

View File

@ -1570,12 +1570,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
left_val.into_pointer_value(), left_val.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
None,
llvm_usize, llvm_usize,
None, None,
); );
let right_val = NDArrayValue::from_pointer_value( let right_val = NDArrayValue::from_pointer_value(
right_val.into_pointer_value(), right_val.into_pointer_value(),
llvm_ndarray_dtype2, llvm_ndarray_dtype2,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1631,6 +1633,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let ndarray_val = NDArrayValue::from_pointer_value( let ndarray_val = NDArrayValue::from_pointer_value(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1828,6 +1831,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
let val = NDArrayValue::from_pointer_value( let val = NDArrayValue::from_pointer_value(
val.into_pointer_value(), val.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1926,6 +1930,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
lhs.into_pointer_value(), lhs.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -2799,6 +2804,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray = NDArrayValue::from_pointer_value( let ndarray = NDArrayValue::from_pointer_value(
subscripted_ndarray, subscripted_ndarray,
llvm_ndarray_data_t, llvm_ndarray_data_t,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -3547,7 +3553,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
} }

View File

@ -3,6 +3,7 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools;
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
@ -27,7 +28,7 @@ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, PrimDef}, helper::{arraylike_flatten_element_type, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::unpack_ndarray_var_tys,
DefinitionId, DefinitionId,
}, },
typecheck::{ typecheck::{
@ -43,19 +44,16 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type, elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
.get_llvm_type(generator, ndarray_ty) .as_base_type()
.into_pointer_type()
.get_element_type() .get_element_type()
.into_struct_type(); .into_struct_type();
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None)) Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None))
} }
/// Creates an `NDArray` instance from a dynamic shape. /// Creates an `NDArray` instance from a dynamic shape.
@ -189,28 +187,10 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
// TODO: Disallow dim_sz > u32_MAX // TODO: Disallow dim_sz > u32_MAX
} }
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.shape().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, true),
None,
)
};
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype)
.construct_dyn_shape(generator, ctx, shape, None);
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray) Ok(ndarray)
@ -338,20 +318,24 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
// Get the length/size of the tuple, which also happens to be the value of `ndims`. // Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields(); let ndims = shape_tuple.get_type().count_fields();
let mut shape = Vec::with_capacity(ndims as usize); let shape = (0..ndims)
for dim_i in 0..ndims { .map(|dim_i| {
let dim = ctx ctx.builder
.builder .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) .map(BasicValueEnum::into_int_value)
.unwrap() .map(|v| {
.into_int_value(); ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
})
.unwrap()
})
.collect_vec();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
} }
BasicValueEnum::IntValue(shape_int) => { BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let shape_int =
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
@ -505,6 +489,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -517,6 +502,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -532,6 +518,7 @@ where
let lhs = NDArrayValue::from_pointer_value( let lhs = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -548,6 +535,7 @@ where
let rhs = NDArrayValue::from_pointer_value( let rhs = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -706,7 +694,8 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
{ {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
.load_ndims(ctx)
} }
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
@ -856,7 +845,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_representable(object, llvm_usize).is_ok() { if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
let ndarray = gen_if_else_expr_callback( let ndarray = gen_if_else_expr_callback(
generator, generator,
@ -932,6 +921,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
return Ok(NDArrayValue::from_pointer_value( return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
)); ));
@ -1465,6 +1455,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1473,6 +1464,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1499,6 +1491,7 @@ where
let ndarray = NDArrayValue::from_pointer_value( let ndarray = NDArrayValue::from_pointer_value(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -2061,6 +2054,7 @@ pub fn gen_ndarray_copy<'ctx>(
NDArrayValue::from_pointer_value( NDArrayValue::from_pointer_value(
this_arg.into_pointer_value(), this_arg.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
), ),
@ -2098,7 +2092,7 @@ pub fn gen_ndarray_fill<'ctx>(
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
context, context,
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, _| { |generator, ctx, _| {
let value = if value_arg.is_pointer_value() { let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
@ -2140,7 +2134,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
@ -2260,7 +2254,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2548,8 +2542,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));

View File

@ -1,5 +1,5 @@
use inkwell::{ use inkwell::{
context::Context, context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue}, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
@ -12,9 +12,13 @@ use super::{
structure::{StructField, StructFields}, structure::{StructField, StructFields},
ProxyType, ProxyType,
}; };
use crate::codegen::{ use crate::{
values::{ArraySliceValue, NDArrayValue, ProxyValue}, codegen::{
{CodeGenContext, CodeGenerator}, values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
{CodeGenContext, CodeGenerator},
},
toplevel::numpy::unpack_ndarray_var_tys,
typecheck::typedef::Type,
}; };
/// Proxy type for a `ndarray` type in LLVM. /// Proxy type for a `ndarray` type in LLVM.
@ -27,10 +31,14 @@ pub struct NDArrayType<'ctx> {
#[derive(PartialEq, Eq, Clone, Copy, StructFields)] #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> { pub struct NDArrayStructFields<'ctx> {
#[value_type(usize)]
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize)] #[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>, pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))] #[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>, pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub strides: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))] #[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>, pub data: StructField<'ctx, PointerValue<'ctx>>,
} }
@ -41,70 +49,45 @@ impl<'ctx> NDArrayType<'ctx> {
llvm_ty: PointerType<'ctx>, llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
let llvm_ndarray_ty = llvm_ty.get_element_type(); let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
}; };
if llvm_ndarray_ty.count_fields() != 3 { if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
return Err(format!( return Err(format!(
"Expected 3 fields in `NDArray`, got {}", "Expected {} fields in `NDArray`, got {}",
llvm_expected_ty.len(),
llvm_ndarray_ty.count_fields() llvm_ndarray_ty.count_fields()
)); ));
} }
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); llvm_expected_ty
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { .iter()
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); .enumerate()
}; .map(|(i, expected_ty)| {
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { (expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
return Err(format!( })
"Expected {}-bit int type for `ndarray.0`, got {}-bit int", .try_for_each(|(expected_ty, actual_ty)| {
llvm_usize.get_bit_width(), if expected_ty == actual_ty {
ndarray_ndims_ty.get_bit_width() Ok(())
)); } else {
} Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); })?;
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
let ndarray_data = ndarray_pdata.get_element_type();
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
));
};
if ndarray_data.get_bit_width() != 8 {
return Err(format!(
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
ndarray_data.get_bit_width()
));
}
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType // TODO: Move this into e.g. StructProxyType
#[must_use] #[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { fn fields(
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize) NDArrayStructFields::new(ctx, llvm_usize)
} }
@ -112,7 +95,7 @@ impl<'ctx> NDArrayType<'ctx> {
#[must_use] #[must_use]
pub fn get_fields( pub fn get_fields(
&self, &self,
ctx: &'ctx Context, ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> { ) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize) Self::fields(ctx, llvm_usize)
@ -121,7 +104,7 @@ impl<'ctx> NDArrayType<'ctx> {
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use] #[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
// //
// * data : Pointer to an array containing the array data // * data : Pointer to an array containing the array data
// * itemsize: The size of each NDArray elements in bytes // * itemsize: The size of each NDArray elements in bytes
@ -147,6 +130,25 @@ impl<'ctx> NDArrayType<'ctx> {
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize } NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
} }
/// Creates an [`NDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
NDArrayType {
ty: Self::llvm_type(ctx.ctx, llvm_usize),
dtype: llvm_dtype,
llvm_usize,
}
}
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use] #[must_use]
pub fn from_type( pub fn from_type(
@ -165,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> {
self.as_base_type() self.as_base_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type()
.get_field_type_at_index(0) .get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type) .map(BasicTypeEnum::into_int_type)
.unwrap() .unwrap()
} }
@ -175,6 +177,107 @@ impl<'ctx> NDArrayType<'ctx> {
pub fn element_type(&self) -> BasicTypeEnum<'ctx> { pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype self.dtype
} }
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated onto the stack.
///
/// The returned ndarray's content will be:
/// - `data`: uninitialized.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
#[must_use]
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: u64,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.new_value(generator, ctx, name);
let itemsize = ctx
.builder
.build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap();
ndarray.store_itemsize(ctx, generator, itemsize);
let ndims_val = self.llvm_usize.const_int(ndims, false);
ndarray.store_ndims(ctx, generator, ndims_val);
ndarray.create_shape(ctx, self.llvm_usize, ndims_val);
ndarray.create_strides(ctx, self.llvm_usize, ndims_val);
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[u64],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
let dim = self.llvm_usize.const_int(*dim, false);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&self.llvm_usize.const_int(i as u64, false),
dim,
);
}
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[IntValue<'ctx>],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
assert_eq!(
dim.get_type(),
self.llvm_usize,
"Expected {} but got {}",
self.llvm_usize.print_to_string(),
dim.get_type().print_to_string()
);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&self.llvm_usize.const_int(i as u64, false),
*dim,
);
}
}
ndarray
}
} }
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
@ -243,7 +346,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
) -> Self::Value { ) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type()); debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name) NDArrayValue::from_pointer_value(value, self.dtype, None, self.llvm_usize, name)
} }
fn as_base_type(&self) -> Self::Base { fn as_base_type(&self) -> Self::Base {

View File

@ -21,6 +21,7 @@ use crate::codegen::{
pub struct NDArrayValue<'ctx> { pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>, value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
} }
@ -40,12 +41,13 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn from_pointer_value( pub fn from_pointer_value(
ptr: PointerValue<'ctx>, ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self { ) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, llvm_usize, name } NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
@ -75,6 +77,33 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
} }
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.itemsize
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling /// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field. /// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
@ -105,6 +134,36 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayShapeProxy(self) NDArrayShapeProxy(self)
} }
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.strides
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
@ -168,103 +227,6 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
} }
} }
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. /// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
@ -521,3 +483,197 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
for NDArrayDataProxy<'ctx, '_> for NDArrayDataProxy<'ctx, '_>
{ {
} }
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}