diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 9c7c262..ff492c7 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -8,7 +8,7 @@ use crate::codegen::{ CodeGenerator, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const}, llvm_intrinsics::call_int_umin, - stmt::gen_for_callback, + stmt::gen_for_callback_incrementing, }; #[cfg(not(debug_assertions))] @@ -940,30 +940,15 @@ impl<'ctx> NDArrayDataProxy<'ctx> { ctx.current_loc, ); - gen_for_callback( + let indices_len = indices.load_size(ctx, None); + let ndarray_len = self.0.load_ndims(ctx); + let len = call_int_umin(ctx, indices_len, ndarray_len, None); + gen_for_callback_incrementing( generator, ctx, - |generator, ctx| { - let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); - - Ok(i) - }, - |_, ctx, i_addr| { - let indices_len = indices.load_size(ctx, None); - let ndarray_len = self.0.load_ndims(ctx); - - let len = call_int_umin(ctx, indices_len, ndarray_len, None); - - let i = ctx.builder.build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "").unwrap()) - }, - |generator, ctx, i_addr| { - let i = ctx.builder.build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + llvm_usize.const_zero(), + (len, false), + |generator, ctx, i| { let (dim_idx, dim_sz) = unsafe { ( indices.data().get_unchecked(ctx, i, None).into_int_value(), @@ -989,16 +974,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> { Ok(()) }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); - ctx.builder.build_store(i_addr, i).unwrap(); - - Ok(()) - }, + llvm_usize.const_int(1, false), ).unwrap(); unsafe { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3f227fe..9c670e0 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -14,7 +14,7 @@ use crate::{ call_ndarray_calc_size, }, llvm_intrinsics::call_memcpy_generic, - stmt::gen_for_callback + stmt::gen_for_callback_incrementing, }, symbol_resolver::ValueEnum, toplevel::{ @@ -52,30 +52,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( assert!(llvm_ndarray_data_t.is_sized()); // Assert that all dimensions are non-negative - gen_for_callback( + let shape_len = shape_len_fn(generator, ctx, shape)?; + gen_for_callback_incrementing( generator, ctx, - |generator, ctx| { - let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); - - Ok(i) - }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let shape_len = shape_len_fn(generator, ctx, shape)?; - debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap()) - }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + llvm_usize.const_zero(), + (shape_len, false), + |generator, ctx, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); @@ -94,16 +77,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( Ok(()) }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); - ctx.builder.build_store(i_addr, i).unwrap(); - - Ok(()) - }, + llvm_usize.const_int(1, false), )?; let ndarray = generator.gen_var_alloc( @@ -120,30 +94,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); // Copy the dimension sizes from shape to ndarray.dims - gen_for_callback( + let shape_len = shape_len_fn(generator, ctx, shape)?; + gen_for_callback_incrementing( generator, ctx, - |generator, ctx| { - let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); - - Ok(i) - }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let shape_len = shape_len_fn(generator, ctx, shape)?; - debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap()) - }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + llvm_usize.const_zero(), + (shape_len, false), + |generator, ctx, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); let shape_dim = ctx.builder @@ -158,16 +115,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( Ok(()) }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); - ctx.builder.build_store(i_addr, i).unwrap(); - - Ok(()) - }, + llvm_usize.const_int(1, false), )?; let ndarray_num_elems = call_ndarray_calc_size( @@ -342,28 +290,12 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( ndarray.dim_sizes().as_ptr_value(ctx), ); - gen_for_callback( + gen_for_callback_incrementing( generator, ctx, - |generator, ctx| { - let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); - - Ok(i) - }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "").unwrap()) - }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + llvm_usize.const_zero(), + (ndarray_num_elems, false), + |generator, ctx, i| { let elem = unsafe { ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None) }; @@ -373,16 +305,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( Ok(()) }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); - ctx.builder.build_store(i_addr, i).unwrap(); - - Ok(()) - }, + llvm_usize.const_int(1, false), ) } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 90eab5b..0983163 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -536,6 +536,84 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>( Ok(()) } +/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the +/// following C code: +/// +/// ```c +/// for (int x = init_val; x /* < or <= ; see `max_val` */ max_val; x += incr_val) { +/// body(x); +/// } +/// ``` +/// +/// * `init_val` - The initial value of the loop variable. The type of this value will also be used +/// as the type of the loop variable. +/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum +/// value should be treated as inclusive (as opposed to exclusive). +/// * `body` - A lambda containing IR statements within the loop body. +/// * `incr_val` - The value to increment the loop variable on each iteration. +pub fn gen_for_callback_incrementing<'ctx, 'a, BodyFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + init_val: IntValue<'ctx>, + max_val: (IntValue<'ctx>, bool), + body: BodyFn, + incr_val: IntValue<'ctx>, +) -> Result<(), String> + where + BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, +{ + let init_val_t = init_val.get_type(); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; + ctx.builder.build_store(i_addr, init_val).unwrap(); + + Ok(i_addr) + }, + |_, ctx, i_addr| { + let cmp_op = if max_val.1 { + IntPredicate::ULE + } else { + IntPredicate::ULT + }; + + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let max_val = ctx.builder + .build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "") + .unwrap(); + + Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) + }, + |generator, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + body(generator, ctx, i) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let incr_val = ctx.builder + .build_int_z_extend_or_bit_cast(incr_val, init_val_t, "") + .unwrap(); + let i = ctx.builder.build_int_add(i, incr_val, "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); + + Ok(()) + }, + ) +} + /// See [`CodeGenerator::gen_while`]. pub fn gen_while( generator: &mut G,