1
0
forked from M-Labs/nac3

core/numpy: Add more helper functions

This commit is contained in:
David Mak 2024-05-29 14:19:12 +08:00
parent b6ff75dcaf
commit 2cf79510c2
2 changed files with 58 additions and 55 deletions

View File

@ -33,6 +33,30 @@ use crate::{
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
}; };
/// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx.get_llvm_type(generator, ndarray_ty)
.into_pointer_type()
.get_element_type()
.into_struct_type();
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None))
}
/// Creates an `NDArray` instance from a dynamic shape. /// Creates an `NDArray` instance from a dynamic shape.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -52,15 +76,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>, LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>, DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
{ {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative // Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing( gen_for_callback_incrementing(
@ -92,12 +109,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let ndarray = generator.gen_var_alloc( let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = shape_len_fn(generator, ctx, shape)?; let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
@ -130,13 +142,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let ndarray_num_elems = call_ndarray_calc_size( let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray) Ok(ndarray)
} }
@ -151,15 +157,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type, elem_ty: Type,
shape: &[IntValue<'ctx>], shape: &[IntValue<'ctx>],
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for shape_dim in shape { for shape_dim in shape {
let shape_dim_gez = ctx.builder let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "") .build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
@ -177,12 +176,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
// TODO: Disallow dim_sz > u32_MAX // TODO: Disallow dim_sz > u32_MAX
} }
let ndarray = generator.gen_var_alloc( let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = llvm_usize.const_int(shape.len() as u64, false); let num_dims = llvm_usize.const_int(shape.len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
@ -200,6 +194,21 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap(); ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
} }
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray)
}
/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields.
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
ndarray: NDArrayValue<'ctx>,
) -> NDArrayValue<'ctx> {
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
@ -208,7 +217,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
); );
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray) ndarray
} }
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(

View File

@ -671,14 +671,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
let start = start_fn(generator, ctx)?; let start = start_fn(generator, ctx)?;
let stop = stop_fn(generator, ctx)?; let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() { let stop = if stop.get_type().get_bit_width() == start.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
}
} else {
stop stop
} else if is_unsigned {
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
}; };
let incr = ctx.builder.build_int_compare( let incr = ctx.builder.build_int_compare(
@ -703,14 +701,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let stop = stop_fn(generator, ctx)?; let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() { let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
}
} else {
stop stop
} else if is_unsigned {
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
}; };
let i_lt_end = ctx.builder let i_lt_end = ctx.builder
@ -742,14 +738,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
.unwrap(); .unwrap();
let incr_val = step_fn(generator, ctx)?; let incr_val = step_fn(generator, ctx)?;
let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() { let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
}
} else {
incr_val incr_val
} else if is_unsigned {
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
}; };
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap(); let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();