core/stmt: Add gen_for_range_callback

For generating for loops over range objects or array slices.
This commit is contained in:
David Mak 2024-05-27 15:11:14 +08:00
parent 82cc693b11
commit 588c15f80d
2 changed files with 150 additions and 0 deletions

View File

@ -658,6 +658,17 @@ impl<'ctx> RangeValue<'ctx> {
RangeValue(ptr, name) RangeValue(ptr, name)
} }
/// Returns the element type of this `range` object.
#[must_use]
pub fn element_type(&self) -> IntType<'ctx> {
self.as_ptr_value()
.get_type()
.get_element_type()
.into_array_type()
.get_element_type()
.into_int_type()
}
/// Returns the underlying [`PointerValue`] pointing to the `range` instance. /// Returns the underlying [`PointerValue`] pointing to the `range` instance.
#[must_use] #[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> { pub fn as_ptr_value(&self) -> PointerValue<'ctx> {

View File

@ -621,6 +621,145 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
) )
} }
/// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following
/// C code:
///
/// ```c
/// bool incr = start_fn() <= end_fn();
/// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) {
/// body_fn(i);
/// }
/// ```
///
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
/// iterable.
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
/// iterable. This value will be extended to the size of `start`.
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body.
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
is_unsigned: bool,
start_fn: StartFn,
(stop_fn, stop_inclusive): (StopFn, bool),
step_fn: StepFn,
body_fn: BodyFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
let init_val_t = start_fn(generator, ctx)
.map(IntValue::get_type)
.unwrap();
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
let start = start_fn(generator, ctx)?;
ctx.builder.build_store(i_addr, start).unwrap();
let start = start_fn(generator, ctx)?;
let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
}
} else {
stop
};
let incr = ctx.builder.build_int_compare(
if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE },
start,
stop,
"",
).unwrap();
Ok((i_addr, incr))
},
|generator, ctx, (i_addr, incr)| {
let (lt_cmp_op, gt_cmp_op) = match (is_unsigned, stop_inclusive) {
(true, true) => (IntPredicate::ULE, IntPredicate::UGE),
(true, false) => (IntPredicate::ULT, IntPredicate::UGT),
(false, true) => (IntPredicate::SLE, IntPredicate::SGE),
(false, false) => (IntPredicate::SLT, IntPredicate::SGT),
};
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
}
} else {
stop
};
let i_lt_end = ctx.builder
.build_int_compare(lt_cmp_op, i, stop, "")
.unwrap();
let i_gt_end = ctx.builder
.build_int_compare(gt_cmp_op, i, stop, "")
.unwrap();
let cond = ctx.builder
.build_select(incr, i_lt_end, i_gt_end, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
Ok(cond)
},
|generator, ctx, (i_addr, _)| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
body_fn(generator, ctx, i)
},
|generator, ctx, (i_addr, _)| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let incr_val = step_fn(generator, ctx)?;
let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() {
if is_unsigned {
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
} else {
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
}
} else {
incr_val
};
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)
}
/// See [`CodeGenerator::gen_while`]. /// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>( pub fn gen_while<G: CodeGenerator>(
generator: &mut G, generator: &mut G,