[core] coregen/types: Implement StructFields for NDArray

Also rename some fields to better align with their naming in numpy.
This commit is contained in:
David Mak 2024-11-13 15:53:29 +08:00
parent c50ed30f75
commit 2d8b7a262c
7 changed files with 165 additions and 147 deletions

View File

@ -498,7 +498,7 @@ fn format_rpc_arg<'ctx>(
call_memcpy_generic(
ctx,
pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator),
llvm_arg.shape().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
@ -612,7 +612,7 @@ fn format_rpc_ret<'ctx>(
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
@ -702,7 +702,7 @@ fn format_rpc_ret<'ctx>(
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
ndarray.shape().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
@ -714,7 +714,7 @@ fn format_rpc_ret<'ctx>(
// `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate)
let num_elements =
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
@ -1379,7 +1379,7 @@ fn polymorphic_print<'ctx>(
llvm_usize,
None,
);
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();

View File

@ -78,7 +78,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
None,
);
let ndims = arg.dim_sizes().size(ctx, generator);
let ndims = arg.shape().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
@ -91,12 +91,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
@ -927,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
@ -1981,12 +1976,12 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2023,12 +2018,12 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2074,12 +2069,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2128,12 +2123,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2171,12 +2166,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2214,12 +2209,12 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2284,12 +2279,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
@ -2362,7 +2357,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
@ -2405,7 +2400,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};

View File

@ -2631,7 +2631,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe {
v.dim_sizes().get_typed_unchecked(
v.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, true),
@ -2672,7 +2672,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes().get_typed_unchecked(
v.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
@ -2813,7 +2813,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx
.builder
@ -2824,7 +2824,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
)
.unwrap();
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
v.shape().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
@ -2833,7 +2833,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
ndarray.shape().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
@ -2845,7 +2845,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
&ndarray.shape().as_slice_value(ctx, generator),
(None, None),
);
let ndarray_num_elems = ctx

View File

@ -103,7 +103,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
@ -172,7 +172,7 @@ where
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
@ -259,8 +259,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
@ -298,9 +298,9 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
@ -362,7 +362,7 @@ pub fn call_ndarray_calc_broadcast_index<
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)

View File

@ -128,7 +128,7 @@ where
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -144,7 +144,7 @@ where
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_pdim =
unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) };
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(
ndarray.shape().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, true),
@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
&ndarray.shape().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
@ -380,7 +380,7 @@ where
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
&ndarray.shape().as_slice_value(ctx, generator),
(None, None),
);
@ -739,7 +739,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
let stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.dim_sizes(),
&dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
@ -1155,7 +1155,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.dim_sizes(),
&src_arr.shape(),
(Some(llvm_usize.const_int(dim, false)), None),
);
let stride =
@ -1173,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let src_stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.dim_sizes(),
&src_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let dst_stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.dim_sizes(),
&dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
@ -1278,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
&this,
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|generator, ctx, shape, idx| unsafe {
Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)?
} else {
@ -1286,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
let ndims = this.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndims);
ndarray.create_shape(ctx, llvm_usize, ndims);
// Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() {
@ -1318,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
unsafe {
ndarray.dim_sizes().set_typed_unchecked(
ndarray.shape().set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
@ -1336,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
(this.load_ndims(ctx), false),
|generator, ctx, _, idx| {
unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz);
}
Ok(())
@ -1397,7 +1397,7 @@ where
&operand,
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
@ -1510,7 +1510,7 @@ where
&ndarray,
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
@ -1571,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if let Some(res) = res {
let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe {
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let res_dim1 = unsafe {
res.dim_sizes().get_typed_unchecked(
res.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
@ -1582,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
)
};
let lhs_dim0 = unsafe {
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let rhs_dim1 = unsafe {
rhs.dim_sizes().get_typed_unchecked(
rhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
@ -1634,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let rhs_dim0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
// lhs.dims[1] == rhs.dims[0]
@ -1681,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
},
|generator, ctx| {
Ok(Some(unsafe {
lhs.dim_sizes().get_typed_unchecked(
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
@ -1691,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
},
|generator, ctx| {
Ok(Some(unsafe {
rhs.dim_sizes().get_typed_unchecked(
rhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
@ -1718,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
let common_dim = {
let lhs_idx1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
@ -1726,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
)
};
let rhs_idx0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
@ -2146,7 +2141,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
@ -2161,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
@ -2198,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
@ -2266,7 +2261,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2494,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
@ -2556,8 +2551,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
ctx.make_assert(
generator,

View File

@ -1,11 +1,16 @@
use inkwell::context::ContextRef;
use inkwell::{
context::Context,
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use super::ProxyType;
use super::{
structure::{FieldIndexCounter, StructField, StructFields},
ProxyType,
};
use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator},
@ -19,6 +24,38 @@ pub struct NDArrayType<'ctx> {
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> {
pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = FieldIndexCounter::default();
NDArrayStructFields {
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
}
}
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![self.ndims.into(), self.shape.into(), self.data.into()]
}
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
@ -86,19 +123,34 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(())
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data : Pointer to an array containing the array data
let field_tys = [
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
ctx.i8_type().ptr_type(AddressSpace::default()).into(),
];
// * itemsize: The size of each NDArray elements in bytes
// * ndims : Number of dimensions in the array
// * shape : Pointer to an array containing the shape of the NDArray
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}

View File

@ -50,18 +50,10 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
@ -83,59 +75,43 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
)
.unwrap()
}
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of data elements `data` into this instance.
@ -194,15 +170,15 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type()
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
@ -213,7 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
@ -227,7 +203,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -266,10 +242,10 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
@ -279,7 +255,7 @@ impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ct
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
@ -491,7 +467,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
let (dim_idx, dim_sz) = unsafe {
(
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None),
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx