From 588c15f80d63445f9f45815d8bb2cd9505f1381e Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 May 2024 15:11:14 +0800 Subject: [PATCH] core/stmt: Add gen_for_range_callback For generating for loops over range objects or array slices. --- nac3core/src/codegen/classes.rs | 11 +++ nac3core/src/codegen/stmt.rs | 139 ++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 3b87bbbe5..b3b6da43d 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -658,6 +658,17 @@ impl<'ctx> RangeValue<'ctx> { RangeValue(ptr, name) } + /// Returns the element type of this `range` object. + #[must_use] + pub fn element_type(&self) -> IntType<'ctx> { + self.as_ptr_value() + .get_type() + .get_element_type() + .into_array_type() + .get_element_type() + .into_int_type() + } + /// Returns the underlying [`PointerValue`] pointing to the `range` instance. #[must_use] pub fn as_ptr_value(&self) -> PointerValue<'ctx> { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index ae0e00984..298982084 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -621,6 +621,145 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( ) } +/// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following +/// C code: +/// +/// ```c +/// bool incr = start_fn() <= end_fn(); +/// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) { +/// body_fn(i); +/// } +/// ``` +/// +/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned. +/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like +/// iterable. +/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like +/// iterable. This value will be extended to the size of `start`. +/// - `stop_inclusive`: Whether the stop value should be treated as inclusive. +/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like +/// iterable. This value will be extended to the size of `start`. +/// - `body_fn`: A lambda of IR statements within the loop body. +pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + is_unsigned: bool, + start_fn: StartFn, + (stop_fn, stop_inclusive): (StopFn, bool), + step_fn: StepFn, + body_fn: BodyFn, +) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, +{ + let init_val_t = start_fn(generator, ctx) + .map(IntValue::get_type) + .unwrap(); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; + + let start = start_fn(generator, ctx)?; + ctx.builder.build_store(i_addr, start).unwrap(); + + let start = start_fn(generator, ctx)?; + let stop = stop_fn(generator, ctx)?; + let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() { + if is_unsigned { + ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap() + } + } else { + stop + }; + + let incr = ctx.builder.build_int_compare( + if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE }, + start, + stop, + "", + ).unwrap(); + + Ok((i_addr, incr)) + }, + |generator, ctx, (i_addr, incr)| { + let (lt_cmp_op, gt_cmp_op) = match (is_unsigned, stop_inclusive) { + (true, true) => (IntPredicate::ULE, IntPredicate::UGE), + (true, false) => (IntPredicate::ULT, IntPredicate::UGT), + (false, true) => (IntPredicate::SLE, IntPredicate::SGE), + (false, false) => (IntPredicate::SLT, IntPredicate::SGT), + }; + + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let stop = stop_fn(generator, ctx)?; + let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() { + if is_unsigned { + ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap() + } + } else { + stop + }; + + let i_lt_end = ctx.builder + .build_int_compare(lt_cmp_op, i, stop, "") + .unwrap(); + let i_gt_end = ctx.builder + .build_int_compare(gt_cmp_op, i, stop, "") + .unwrap(); + + let cond = ctx.builder + .build_select(incr, i_lt_end, i_gt_end, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + Ok(cond) + }, + |generator, ctx, (i_addr, _)| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + body_fn(generator, ctx, i) + }, + |generator, ctx, (i_addr, _)| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + let incr_val = step_fn(generator, ctx)?; + let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() { + if is_unsigned { + ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap() + } else { + ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap() + } + } else { + incr_val + }; + + let i = ctx.builder.build_int_add(i, incr_val, "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); + + Ok(()) + }, + ) +} + /// See [`CodeGenerator::gen_while`]. pub fn gen_while( generator: &mut G,