1
0
forked from M-Labs/nac3

nac3core/codegen: list comprehension support

This commit is contained in:
pca006132 2021-10-24 16:53:04 +08:00
parent 45673b0ecc
commit 558c3f03ef

View File

@ -410,6 +410,141 @@ pub fn allocate_list<'ctx, 'a>(
} }
} }
pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
expr: &Expr<Option<Type>>,
) -> BasicValueEnum<'ctx> {
if let ExprKind::ListComp { elt, generators } = &expr.node {
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = ctx.ctx.append_basic_block(current, "test");
let body_bb = ctx.ctx.append_basic_block(current, "body");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
let Comprehension { target, iter, ifs, .. } = &generators[0];
let iter_val = generator.gen_expr(ctx, iter).unwrap();
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let index = generator.gen_var_alloc(ctx, ctx.primitives.int32);
// counter = -1
ctx.builder.build_store(index, ctx.ctx.i32_type().const_zero());
let elem_ty = ctx.get_llvm_type(elt.custom.unwrap());
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
let list;
let list_content;
if is_range {
let iter_val = iter_val.into_pointer_value();
let (start, end, step) = destructure_range(ctx, iter_val);
let diff = ctx.builder.build_int_sub(end, start, "diff");
// add 1 to the length as the value is rounded to zero
// the length may be 1 more than the actual length if the division is exact, but the
// length is a upper bound only anyway so it does not matter.
let length = ctx.builder.build_int_signed_div(diff, step, "div");
let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1");
// in case length is non-positive
let is_valid =
ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, length, zero, "check");
let normal = ctx.ctx.append_basic_block(current, "normal_list");
let empty = ctx.ctx.append_basic_block(current, "empty_list");
let list_init = ctx.ctx.append_basic_block(current, "list_init");
ctx.builder.build_conditional_branch(is_valid, normal, empty);
// normal: allocate a list
ctx.builder.position_at_end(normal);
let list_a = allocate_list(ctx, elem_ty, length);
ctx.builder.build_unconditional_branch(list_init);
ctx.builder.position_at_end(empty);
let list_b = allocate_list(ctx, elem_ty, zero);
ctx.builder.build_unconditional_branch(list_init);
ctx.builder.position_at_end(list_init);
let phi = ctx.builder.build_phi(list_a.get_type(), "phi");
phi.add_incoming(&[(&list_a, normal), (&list_b, empty)]);
list = phi.as_basic_value().into_pointer_value();
list_content = ctx
.build_gep_and_load(list, &[zero, int32.const_int(1, false)])
.into_pointer_value();
let i = generator.gen_store_target(ctx, target);
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init"));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb);
let sign = ctx.builder.build_int_compare(
inkwell::IntPredicate::SGT,
step,
int32.const_zero(),
"sign",
);
// add and test
let tmp = ctx.builder.build_int_add(
ctx.builder.build_load(i, "i").into_int_value(),
step,
"start_loop",
);
ctx.builder.build_store(i, tmp);
// // if step > 0, continue when i < end
let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1");
// if step < 0, continue when i > end
let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2");
let pos = ctx.builder.build_and(sign, cmp1, "pos");
let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg");
ctx.builder.build_conditional_branch(
ctx.builder.build_or(pos, neg, "or"),
body_bb,
cont_bb,
);
ctx.builder.position_at_end(body_bb);
} else {
let length = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero])
.into_int_value();
list = allocate_list(ctx, elem_ty, length);
list_content = ctx
.build_gep_and_load(list, &[zero, int32.const_int(1, false)])
.into_pointer_value();
let counter = generator.gen_var_alloc(ctx, ctx.primitives.int32);
// counter = -1
ctx.builder.build_store(counter, ctx.ctx.i32_type().const_int(u64::max_value(), true));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb);
let tmp = ctx.builder.build_load(counter, "i").into_int_value();
let tmp = ctx.builder.build_int_add(tmp, int32.const_int(1, false), "inc");
ctx.builder.build_store(counter, tmp);
let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, length, "cmp");
ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb);
ctx.builder.position_at_end(body_bb);
let arr_ptr = ctx
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
)
.into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp]);
generator.gen_assign(ctx, target, val);
}
for cond in ifs.iter() {
let result = generator.gen_expr(ctx, cond).unwrap().into_int_value();
let succ = ctx.ctx.append_basic_block(current, "then");
ctx.builder.build_conditional_branch(result, succ, test_bb);
ctx.builder.position_at_end(succ);
}
let elem = generator.gen_expr(ctx, elt).unwrap();
let i = ctx.builder.build_load(index, "i").into_int_value();
let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") };
ctx.builder.build_store(elem_ptr, elem);
ctx.builder
.build_store(index, ctx.builder.build_int_add(i, int32.const_int(1, false), "inc"));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(cont_bb);
let len_ptr = unsafe { ctx.builder.build_gep(list, &[zero, zero], "length") };
ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index"));
list.into()
} else {
unreachable!()
}
}
pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -731,6 +866,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>(
ctx.build_gep_and_load(v, &[int32.const_zero(), index]) ctx.build_gep_and_load(v, &[int32.const_zero(), index])
} }
} }
ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr),
_ => unimplemented!(), _ => unimplemented!(),
}) })
} }