[core] coregen/types: Implement StructFields for NDArray
Also rename some fields to better align with their naming in numpy.
This commit is contained in:
parent
b3b61959dd
commit
41fb27c913
|
@ -498,7 +498,7 @@ fn format_rpc_arg<'ctx>(
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
ctx,
|
ctx,
|
||||||
pbuffer_dims_begin,
|
pbuffer_dims_begin,
|
||||||
llvm_arg.dim_sizes().base_ptr(ctx, generator),
|
llvm_arg.shape().base_ptr(ctx, generator),
|
||||||
dims_buf_sz,
|
dims_buf_sz,
|
||||||
llvm_i1.const_zero(),
|
llvm_i1.const_zero(),
|
||||||
);
|
);
|
||||||
|
@ -612,7 +612,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
// Set `ndarray.ndims`
|
// Set `ndarray.ndims`
|
||||||
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||||
// Allocate `ndarray.shape` [size_t; ndims]
|
// Allocate `ndarray.shape` [size_t; ndims]
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||||
|
|
||||||
/*
|
/*
|
||||||
ndarray now:
|
ndarray now:
|
||||||
|
@ -702,7 +702,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
|
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
ctx,
|
ctx,
|
||||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
ndarray.shape().base_ptr(ctx, generator),
|
||||||
pbuffer_dims,
|
pbuffer_dims,
|
||||||
sizeof_dims,
|
sizeof_dims,
|
||||||
llvm_i1.const_zero(),
|
llvm_i1.const_zero(),
|
||||||
|
@ -714,7 +714,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
// `ndarray.shape` must be initialized beforehand in this implementation
|
// `ndarray.shape` must be initialized beforehand in this implementation
|
||||||
// (for ndarray.create_data() to know how many elements to allocate)
|
// (for ndarray.create_data() to know how many elements to allocate)
|
||||||
let num_elements =
|
let num_elements =
|
||||||
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
|
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
|
||||||
|
|
||||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
@ -1373,7 +1373,7 @@ fn polymorphic_print<'ctx>(
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
|
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
|
||||||
let last =
|
let last =
|
||||||
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
let ndims = arg.dim_sizes().size(ctx, generator);
|
let ndims = arg.shape().size(ctx, generator);
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder
|
ctx.builder
|
||||||
|
@ -91,12 +91,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
);
|
);
|
||||||
|
|
||||||
let len = unsafe {
|
let len = unsafe {
|
||||||
arg.dim_sizes().get_typed_unchecked(
|
arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
||||||
|
@ -927,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx
|
let n_sz_eqz = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -1981,12 +1976,12 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2023,12 +2018,12 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2074,12 +2069,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2128,12 +2123,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2171,12 +2166,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2214,12 +2209,12 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let dim1 = unsafe {
|
let dim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2284,12 +2279,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||||
|
|
||||||
let outdim0 = unsafe {
|
let outdim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
let outdim1 = unsafe {
|
let outdim1 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2362,7 +2357,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
@ -2405,7 +2400,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
};
|
};
|
||||||
|
|
|
@ -2631,7 +2631,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
let len = unsafe {
|
let len = unsafe {
|
||||||
v.dim_sizes().get_typed_unchecked(
|
v.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(dim, true),
|
&llvm_usize.const_int(dim, true),
|
||||||
|
@ -2672,7 +2672,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
ExprKind::Slice { lower, upper, step } => {
|
ExprKind::Slice { lower, upper, step } => {
|
||||||
let dim_sz = unsafe {
|
let dim_sz = unsafe {
|
||||||
v.dim_sizes().get_typed_unchecked(
|
v.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(dim, false),
|
&llvm_usize.const_int(dim, false),
|
||||||
|
@ -2813,7 +2813,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
);
|
);
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
let ndarray_num_dims = ctx
|
let ndarray_num_dims = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -2824,7 +2824,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let v_dims_src_ptr = unsafe {
|
let v_dims_src_ptr = unsafe {
|
||||||
v.dim_sizes().ptr_offset_unchecked(
|
v.shape().ptr_offset_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
&llvm_usize.const_int(1, false),
|
||||||
|
@ -2833,7 +2833,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
};
|
};
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
ctx,
|
ctx,
|
||||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
ndarray.shape().base_ptr(ctx, generator),
|
||||||
v_dims_src_ptr,
|
v_dims_src_ptr,
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||||
|
@ -2845,7 +2845,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
(None, None),
|
(None, None),
|
||||||
);
|
);
|
||||||
let ndarray_num_elems = ctx
|
let ndarray_num_elems = ctx
|
||||||
|
|
|
@ -103,7 +103,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
});
|
});
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
let ndarray_dims = ndarray.dim_sizes();
|
let ndarray_dims = ndarray.shape();
|
||||||
|
|
||||||
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ where
|
||||||
});
|
});
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
let ndarray_dims = ndarray.dim_sizes();
|
let ndarray_dims = ndarray.shape();
|
||||||
|
|
||||||
let index = ctx
|
let index = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -259,8 +259,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||||
(
|
(
|
||||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -298,9 +298,9 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||||
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
|
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
|
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||||
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||||
|
@ -362,7 +362,7 @@ pub fn call_ndarray_calc_broadcast_index<
|
||||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||||
|
|
||||||
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
|
let array_dims = array.shape().base_ptr(ctx, generator);
|
||||||
let array_ndims = array.load_ndims(ctx);
|
let array_ndims = array.load_ndims(ctx);
|
||||||
let broadcast_idx_ptr = unsafe {
|
let broadcast_idx_ptr = unsafe {
|
||||||
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
|
|
@ -128,7 +128,7 @@ where
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
// Copy the dimension sizes from shape to ndarray.dims
|
// Copy the dimension sizes from shape to ndarray.dims
|
||||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
@ -144,7 +144,7 @@ where
|
||||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
let ndarray_pdim =
|
let ndarray_pdim =
|
||||||
unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) };
|
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
|
||||||
|
|
||||||
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
||||||
|
|
||||||
|
@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
for (i, &shape_dim) in shape.iter().enumerate() {
|
for (i, &shape_dim) in shape.iter().enumerate() {
|
||||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
let ndarray_dim = unsafe {
|
let ndarray_dim = unsafe {
|
||||||
ndarray.dim_sizes().ptr_offset_unchecked(
|
ndarray.shape().ptr_offset_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(i as u64, true),
|
&llvm_usize.const_int(i as u64, true),
|
||||||
|
@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
(None, None),
|
(None, None),
|
||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
@ -380,7 +380,7 @@ where
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
(None, None),
|
(None, None),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -739,7 +739,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let stride = call_ndarray_calc_size(
|
let stride = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&dst_arr.dim_sizes(),
|
&dst_arr.shape(),
|
||||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1155,7 +1155,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let stride = call_ndarray_calc_size(
|
let stride = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&src_arr.dim_sizes(),
|
&src_arr.shape(),
|
||||||
(Some(llvm_usize.const_int(dim, false)), None),
|
(Some(llvm_usize.const_int(dim, false)), None),
|
||||||
);
|
);
|
||||||
let stride =
|
let stride =
|
||||||
|
@ -1173,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let src_stride = call_ndarray_calc_size(
|
let src_stride = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&src_arr.dim_sizes(),
|
&src_arr.shape(),
|
||||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||||
);
|
);
|
||||||
let dst_stride = call_ndarray_calc_size(
|
let dst_stride = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&dst_arr.dim_sizes(),
|
&dst_arr.shape(),
|
||||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1278,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
&this,
|
&this,
|
||||||
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|
||||||
|generator, ctx, shape, idx| unsafe {
|
|generator, ctx, shape, idx| unsafe {
|
||||||
Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
|
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||||
},
|
},
|
||||||
)?
|
)?
|
||||||
} else {
|
} else {
|
||||||
|
@ -1286,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
||||||
|
|
||||||
let ndims = this.load_ndims(ctx);
|
let ndims = this.load_ndims(ctx);
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndims);
|
ndarray.create_shape(ctx, llvm_usize, ndims);
|
||||||
|
|
||||||
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
||||||
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
||||||
|
@ -1318,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
|
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
ndarray.dim_sizes().set_typed_unchecked(
|
ndarray.shape().set_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(i as u64, false),
|
&llvm_usize.const_int(i as u64, false),
|
||||||
|
@ -1336,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
(this.load_ndims(ctx), false),
|
(this.load_ndims(ctx), false),
|
||||||
|generator, ctx, _, idx| {
|
|generator, ctx, _, idx| {
|
||||||
unsafe {
|
unsafe {
|
||||||
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
|
let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
|
||||||
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1397,7 +1397,7 @@ where
|
||||||
&operand,
|
&operand,
|
||||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
||||||
|generator, ctx, v, idx| unsafe {
|
|generator, ctx, v, idx| unsafe {
|
||||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1510,7 +1510,7 @@ where
|
||||||
&ndarray,
|
&ndarray,
|
||||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
||||||
|generator, ctx, v, idx| unsafe {
|
|generator, ctx, v, idx| unsafe {
|
||||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1571,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
if let Some(res) = res {
|
if let Some(res) = res {
|
||||||
let res_ndims = res.load_ndims(ctx);
|
let res_ndims = res.load_ndims(ctx);
|
||||||
let res_dim0 = unsafe {
|
let res_dim0 = unsafe {
|
||||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
};
|
};
|
||||||
let res_dim1 = unsafe {
|
let res_dim1 = unsafe {
|
||||||
res.dim_sizes().get_typed_unchecked(
|
res.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
&llvm_usize.const_int(1, false),
|
||||||
|
@ -1582,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let lhs_dim0 = unsafe {
|
let lhs_dim0 = unsafe {
|
||||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
};
|
};
|
||||||
let rhs_dim1 = unsafe {
|
let rhs_dim1 = unsafe {
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
rhs.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
&llvm_usize.const_int(1, false),
|
||||||
|
@ -1634,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let lhs_dim1 = unsafe {
|
let lhs_dim1 = unsafe {
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
let rhs_dim0 = unsafe {
|
let rhs_dim0 = unsafe {
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
};
|
};
|
||||||
|
|
||||||
// lhs.dims[1] == rhs.dims[0]
|
// lhs.dims[1] == rhs.dims[0]
|
||||||
|
@ -1681,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
},
|
},
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
Ok(Some(unsafe {
|
Ok(Some(unsafe {
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
lhs.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_zero(),
|
&llvm_usize.const_zero(),
|
||||||
|
@ -1691,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
},
|
},
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
Ok(Some(unsafe {
|
Ok(Some(unsafe {
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
rhs.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
&llvm_usize.const_int(1, false),
|
||||||
|
@ -1718,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
let common_dim = {
|
let common_dim = {
|
||||||
let lhs_idx1 = unsafe {
|
let lhs_idx1 = unsafe {
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
lhs.shape().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
&llvm_usize.const_int(1, false),
|
||||||
|
@ -1726,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let rhs_idx0 = unsafe {
|
let rhs_idx0 = unsafe {
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
};
|
};
|
||||||
|
|
||||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||||
|
@ -2146,7 +2141,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
// Dimensions are reversed in the transposed array
|
||||||
let out = create_ndarray_dyn_shape(
|
let out = create_ndarray_dyn_shape(
|
||||||
|
@ -2161,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.builder
|
.builder
|
||||||
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -2198,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let dim = unsafe {
|
let dim = unsafe {
|
||||||
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
let rem_idx_val =
|
let rem_idx_val =
|
||||||
|
@ -2266,7 +2261,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
@ -2494,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
);
|
);
|
||||||
|
|
||||||
// The new shape must be compatible with the old shape
|
// The new shape must be compatible with the old shape
|
||||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
|
@ -2556,8 +2551,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
||||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
|
use inkwell::context::ContextRef;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::{AsContextRef, Context},
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::IntValue,
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
};
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::ProxyType;
|
use super::{
|
||||||
|
structure::{FieldIndexCounter, StructField, StructFields},
|
||||||
|
ProxyType,
|
||||||
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
||||||
{CodeGenContext, CodeGenerator},
|
{CodeGenContext, CodeGenerator},
|
||||||
|
@ -19,6 +24,38 @@ pub struct NDArrayType<'ctx> {
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
||||||
|
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
|
||||||
|
let mut counter = FieldIndexCounter::default();
|
||||||
|
|
||||||
|
NDArrayStructFields {
|
||||||
|
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||||
|
shape: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"shape",
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
data: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"data",
|
||||||
|
ctx.i8_type().ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
|
vec![self.ndims.into(), self.shape.into(), self.data.into()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> NDArrayType<'ctx> {
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||||
pub fn is_representable(
|
pub fn is_representable(
|
||||||
|
@ -86,19 +123,34 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
||||||
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_fields(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> NDArrayStructFields<'ctx> {
|
||||||
|
Self::fields(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||||
//
|
//
|
||||||
// * num_dims: Number of dimensions in the array
|
// * data : Pointer to an array containing the array data
|
||||||
// * dims: Pointer to an array containing the size of each dimension
|
// * itemsize: The size of each NDArray elements in bytes
|
||||||
// * data: Pointer to an array containing the array data
|
// * ndims : Number of dimensions in the array
|
||||||
let field_tys = [
|
// * shape : Pointer to an array containing the shape of the NDArray
|
||||||
llvm_usize.into(),
|
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
||||||
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
let field_tys =
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()).into(),
|
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
];
|
|
||||||
|
|
||||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,18 +50,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
self.get_type()
|
||||||
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.ndims
|
||||||
unsafe {
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the number of dimensions `ndims` into this instance.
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
|
@ -83,59 +75,43 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
/// on the field.
|
/// `getelementptr` on the field.
|
||||||
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
self.get_type()
|
||||||
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.shape
|
||||||
unsafe {
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
|
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||||
pub fn create_dim_sizes(
|
pub fn create_shape(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
size: IntValue<'ctx>,
|
size: IntValue<'ctx>,
|
||||||
) {
|
) {
|
||||||
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
|
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
|
||||||
NDArrayDimsProxy(self)
|
NDArrayShapeProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
self.get_type()
|
||||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.data
|
||||||
unsafe {
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of data elements `data` into this instance.
|
/// Stores the array of data elements `data` into this instance.
|
||||||
|
@ -194,15 +170,15 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
|
||||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
fn element_type<G: CodeGenerator + ?Sized>(
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
generator: &G,
|
||||||
) -> AnyTypeEnum<'ctx> {
|
) -> AnyTypeEnum<'ctx> {
|
||||||
self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type()
|
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -213,7 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
|
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -227,7 +203,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
|
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -266,10 +242,10 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {}
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
|
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
fn downcast_to_type(
|
fn downcast_to_type(
|
||||||
&self,
|
&self,
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -279,7 +255,7 @@ impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ct
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {
|
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
fn upcast_from_type(
|
fn upcast_from_type(
|
||||||
&self,
|
&self,
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -491,7 +467,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||||
let (dim_idx, dim_sz) = unsafe {
|
let (dim_idx, dim_sz) = unsafe {
|
||||||
(
|
(
|
||||||
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
||||||
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None),
|
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let dim_idx = ctx
|
let dim_idx = ctx
|
||||||
|
|
Loading…
Reference in New Issue