refactor: merge gen_for_enumerate_list with gen_for_enumerate_tuple

This commit is contained in:
rclovis
2026-01-30 16:11:08 +01:00
parent cfee88277d
commit ce3c24485a

View File

@@ -659,130 +659,75 @@ where
Ok(())
}
/// Generates a `for` statement with `enumerate(tuple)` as its iterable object.
///
/// * `tuple` - The tuple to iterate over.
/// * `element_ty` - The type of the tuple elements, if known.
/// * `length` - The length of the tuple.
/// * `start` - The starting index for enumeration.
/// * `target_i` - The pointer to store the current index.
/// * `body` - The body of the loop.
/// * `orelse` - The `else` block of the loop.
#[allow(clippy::too_many_arguments)]
fn gen_for_enumerate_tuple<'ctx, G: CodeGenerator>(
generator: &mut G,
fn build_tuple_elem_switch<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
tuple: TupleValue<'ctx>,
tuple_len: u64,
element_ty: Option<Type>,
length: u64,
start: IntValue<'ctx>,
target_i: PointerValue<'ctx>,
body: &[Stmt<Option<Type>>],
orelse: &[Stmt<Option<Type>>],
) -> Result<(), String> {
next_i: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
let int32 = ctx.i32;
let update_bb = ctx.builder.get_insert_block().unwrap();
let merge_bb = ctx.ctx.insert_basic_block_after(update_bb, "tuple.merge");
let default_element_ty = ctx.get_llvm_type(element_ty.unwrap_or(ctx.primitives.int32));
gen_for_callback(
generator,
ctx,
None,
|_, ctx| {
let element_struct = ctx.ctx.struct_type(&[int32.into(), default_element_ty], false);
let iv_pair = gen_var(ctx, element_struct, Some("for.v.addr"));
let i = ctx.builder.build_struct_gep(iv_pair, 0, "i").unwrap();
ctx.builder.build_store(i, start).unwrap();
if element_ty.is_some() {
let first_v = tuple.extract(ctx, 0);
let v = ctx.builder.build_struct_gep(iv_pair, 1, "v").unwrap();
ctx.builder.build_store(v, first_v).unwrap();
}
Ok(iv_pair)
},
|_, ctx, iv_pair| {
let i = ctx.builder.build_struct_gep(iv_pair, 0, "i").unwrap();
let i_val =
ctx.builder.build_load(i, "i_val").map(BasicValueEnum::into_int_value).unwrap();
Ok(gen_in_range_check(
ctx,
ctx.builder.build_int_sub(i_val, start, "sub").unwrap(),
int32.const_int(length, false),
int32.const_int(1, false),
))
},
|generator, ctx, _, iv_pair| {
ctx.builder
.build_store(target_i, ctx.builder.build_load(iv_pair, "iv").unwrap())
.unwrap();
generator.gen_block(ctx, body.iter())?;
Ok(())
},
|_, ctx, iv_pair| {
let update_bb = ctx.builder.get_insert_block().unwrap();
if element_ty.is_some() {
let merge_bb = ctx.ctx.insert_basic_block_after(update_bb, "tuple.merge");
let mut tmp_bb = update_bb;
let mut cases = Vec::new();
for idx in 0..length {
let case_bb =
ctx.ctx.insert_basic_block_after(tmp_bb, &format!("tuple.case.{idx}"));
cases.push((int32.const_int(idx, false), case_bb));
tmp_bb = case_bb;
}
let i = ctx.builder.build_struct_gep(iv_pair, 0, "i").unwrap();
let i_val =
ctx.builder.build_load(i, "i_val").map(BasicValueEnum::into_int_value).unwrap();
let next_i =
ctx.builder.build_int_add(i_val, int32.const_int(1, false), "inc").unwrap();
ctx.builder.build_store(i, next_i).unwrap();
ctx.builder
.build_switch(
ctx.builder.build_int_sub(next_i, start, "sub").unwrap(),
merge_bb,
&cases,
)
.unwrap();
ctx.builder.position_at_end(merge_bb);
let phi = ctx.builder.build_phi(default_element_ty, "tuple.elem.phi").unwrap();
for (idx, (_, case_bb)) in cases.iter().take(length as usize).enumerate() {
ctx.builder.position_at_end(*case_bb);
let elem_val = tuple.extract(ctx, idx as u32);
ctx.builder.build_unconditional_branch(merge_bb).unwrap();
phi.add_incoming(&[(&elem_val, *case_bb)]);
}
ctx.builder.position_at_end(merge_bb);
let default_value = default_element_ty.const_zero();
phi.add_incoming(&[(&default_value, update_bb)]);
let next_v = phi.as_basic_value();
let v = ctx.builder.build_struct_gep(iv_pair, 1, "v").unwrap();
ctx.builder.build_store(v, next_v).unwrap();
}
Ok(())
},
|generator, ctx| generator.gen_block(ctx, orelse.iter()),
)
let mut tmp_bb = update_bb;
let mut cases = Vec::new();
for idx in 0..tuple_len {
let case_bb = ctx.ctx.insert_basic_block_after(tmp_bb, &format!("tuple.case.{idx}"));
cases.push((int32.const_int(idx, false), case_bb));
tmp_bb = case_bb;
}
ctx.builder
.build_switch(ctx.builder.build_int_sub(next_i, start, "sub").unwrap(), merge_bb, &cases)
.unwrap();
ctx.builder.position_at_end(merge_bb);
let phi = ctx.builder.build_phi(default_element_ty, "tuple.elem.phi").unwrap();
for (idx, (_, case_bb)) in cases.iter().take(tuple_len as usize).enumerate() {
ctx.builder.position_at_end(*case_bb);
let elem_val = tuple.extract(ctx, idx as u32);
ctx.builder.build_unconditional_branch(merge_bb).unwrap();
phi.add_incoming(&[(&elem_val, *case_bb)]);
}
ctx.builder.position_at_end(merge_bb);
let default_value = default_element_ty.const_zero();
phi.add_incoming(&[(&default_value, update_bb)]);
phi.as_basic_value()
}
/// Generates a `for` statement with `enumerate(list)` as its iterable object.
/// Generates a `for` statement with `enumerate(iterable)` as its iterable object.
///
/// * `list` - The list to iterate over.
/// * `element_ty` - The type of the list elements, if known.
/// * `length` - The length of the list.
/// * `element_ty` - The type of the iterable elements, if known.
/// * `length` - The length of the iterable.
/// * `start` - The starting index for enumeration.
/// * `target_i` - The pointer to store the current index.
/// * `get_first_elem` - A closure that returns the first element of the iterable.
/// * `get_next_elem` - A closure that returns the next element given the next index.
/// * `body` - The body of the loop.
/// * `orelse` - The `else` block of the loop.
#[allow(clippy::too_many_arguments)]
fn gen_for_enumerate_list<'ctx, G: CodeGenerator>(
fn gen_for_enumerate<'ctx, G, GetFirst, GetNext>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListValue<'ctx>,
element_ty: Option<Type>,
length: IntValue<'ctx>,
start: IntValue<'ctx>,
target_i: PointerValue<'ctx>,
get_first_elem: GetFirst,
get_next_elem: GetNext,
body: &[Stmt<Option<Type>>],
orelse: &[Stmt<Option<Type>>],
) -> Result<(), String> {
) -> Result<(), String>
where
G: CodeGenerator,
GetFirst: Fn(&mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>,
GetNext: Fn(&mut CodeGenContext<'ctx, '_>, IntValue<'ctx>) -> BasicValueEnum<'ctx>,
{
let int32 = ctx.i32;
let default_element_ty = ctx.get_llvm_type(element_ty.unwrap_or(ctx.primitives.int32));
gen_for_callback(
@@ -795,8 +740,7 @@ fn gen_for_enumerate_list<'ctx, G: CodeGenerator>(
let i = ctx.builder.build_struct_gep(iv_pair, 0, "i").unwrap();
ctx.builder.build_store(i, start).unwrap();
if element_ty.is_some() {
let first_v: BasicValueEnum =
list.data(ctx).get_unchecked(ctx, &int32.const_int(0, false), Some("first_v"));
let first_v = get_first_elem(ctx);
let v = ctx.builder.build_struct_gep(iv_pair, 1, "v").unwrap();
ctx.builder.build_store(v, first_v).unwrap();
}
@@ -828,11 +772,7 @@ fn gen_for_enumerate_list<'ctx, G: CodeGenerator>(
ctx.builder.build_int_add(i_val, int32.const_int(1, false), "inc").unwrap();
ctx.builder.build_store(i, next_i).unwrap();
if element_ty.is_some() {
let next_v: BasicValueEnum = list.data(ctx).get_unchecked(
ctx,
&ctx.builder.build_int_sub(next_i, start, "sub").unwrap(),
Some("next_v"),
);
let next_v = get_next_elem(ctx, next_i);
let v = ctx.builder.build_struct_gep(iv_pair, 1, "v").unwrap();
ctx.builder.build_store(v, next_v).unwrap();
}
@@ -986,8 +926,29 @@ pub fn gen_for<G: CodeGenerator>(
let val = arraylike_flatten_element_type(&mut ctx.unifier, iterable_ty);
let element_ty =
if ctx.unifier.is_concrete(val, &[]) { Some(val) } else { None };
gen_for_enumerate_list(
generator, ctx, iterable, element_ty, length, start, target_i, body, orelse,
gen_for_enumerate(
generator,
ctx,
element_ty,
length,
start,
target_i,
|ctx| {
iterable.data(ctx).get_unchecked(
ctx,
&int32.const_int(0, false),
Some("first_v"),
)
},
|ctx, next_i| {
iterable.data(ctx).get_unchecked(
ctx,
&ctx.builder.build_int_sub(next_i, start, "sub").unwrap(),
Some("next_v"),
)
},
body,
orelse,
)?;
}
@@ -995,14 +956,21 @@ pub fn gen_for<G: CodeGenerator>(
let iterable = TupleType::from_unifier_type(ctx, iterable_ty)
.map_value(iterable_val.into_struct_value(), Some("tuple"));
let element_ty = if tuple_tys.is_empty() { None } else { Some(tuple_tys[0]) };
gen_for_enumerate_tuple(
let length = int32.const_int(tuple_tys.len() as u64, false);
let tuple_len = tuple_tys.len() as u64;
gen_for_enumerate(
generator,
ctx,
iterable,
element_ty,
tuple_tys.len() as u64,
length,
start,
target_i,
|ctx| iterable.extract(ctx, 0),
|ctx, next_i| {
build_tuple_elem_switch(
ctx, iterable, tuple_len, element_ty, start, next_i,
)
},
body,
orelse,
)?;