From 2cf79510c29deda0123a62e70a90205fc6995cf0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 29 May 2024 14:19:12 +0800 Subject: [PATCH] core/numpy: Add more helper functions --- nac3core/src/codegen/numpy.rs | 77 +++++++++++++++++++---------------- nac3core/src/codegen/stmt.rs | 36 +++++++--------- 2 files changed, 58 insertions(+), 55 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e44232f..1036d04 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -33,6 +33,30 @@ use crate::{ typecheck::typedef::{FunSignature, Type}, }; +/// Creates an uninitialized `NDArray` instance. +fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, +) -> Result, String> { + let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_ndarray_t = ctx.get_llvm_type(generator, ndarray_ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; + + Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) +} + /// Creates an `NDArray` instance from a dynamic shape. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -52,15 +76,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result, String>, { - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); - assert!(llvm_ndarray_data_t.is_sized()); - // Assert that all dimensions are non-negative let shape_len = shape_len_fn(generator, ctx, shape)?; gen_for_callback_incrementing( @@ -92,12 +109,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( llvm_usize.const_int(1, false), )?; - let ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None, - )?; - let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); + let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let num_dims = shape_len_fn(generator, ctx, shape)?; ndarray.store_ndims(ctx, generator, num_dims); @@ -130,13 +142,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( llvm_usize.const_int(1, false), )?; - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); Ok(ndarray) } @@ -151,15 +157,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( elem_ty: Type, shape: &[IntValue<'ctx>], ) -> Result, String> { - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); - assert!(llvm_ndarray_data_t.is_sized()); - for shape_dim in shape { let shape_dim_gez = ctx.builder .build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "") @@ -177,12 +176,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( // TODO: Disallow dim_sz > u32_MAX } - let ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None, - )?; - let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); + let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let num_dims = llvm_usize.const_int(shape.len() as u64, false); ndarray.store_ndims(ctx, generator, num_dims); @@ -200,6 +194,21 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap(); } + let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); + + Ok(ndarray) +} + +/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields. +fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + ndarray: NDArrayValue<'ctx>, +) -> NDArrayValue<'ctx> { + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, @@ -208,7 +217,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - Ok(ndarray) + ndarray } fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 2989820..e100f31 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -671,14 +671,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( let start = start_fn(generator, ctx)?; let stop = stop_fn(generator, ctx)?; - let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() { - if is_unsigned { - ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap() - } else { - ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap() - } - } else { + let stop = if stop.get_type().get_bit_width() == start.get_type().get_bit_width() { stop + } else if is_unsigned { + ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap() }; let incr = ctx.builder.build_int_compare( @@ -703,14 +701,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( .map(BasicValueEnum::into_int_value) .unwrap(); let stop = stop_fn(generator, ctx)?; - let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() { - if is_unsigned { - ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap() - } else { - ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap() - } - } else { + let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() { stop + } else if is_unsigned { + ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap() }; let i_lt_end = ctx.builder @@ -742,14 +738,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( .unwrap(); let incr_val = step_fn(generator, ctx)?; - let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() { - if is_unsigned { - ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap() - } else { - ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap() - } - } else { + let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() { incr_val + } else if is_unsigned { + ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap() }; let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();