diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index fa1c2331..a5f35cc7 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -410,6 +410,141 @@ pub fn allocate_list<'ctx, 'a>( } } +pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + expr: &Expr>, +) -> BasicValueEnum<'ctx> { + if let ExprKind::ListComp { elt, generators } = &expr.node { + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "test"); + let body_bb = ctx.ctx.append_basic_block(current, "body"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + + let Comprehension { target, iter, ifs, .. } = &generators[0]; + let iter_val = generator.gen_expr(ctx, iter).unwrap(); + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + + let index = generator.gen_var_alloc(ctx, ctx.primitives.int32); + // counter = -1 + ctx.builder.build_store(index, ctx.ctx.i32_type().const_zero()); + + let elem_ty = ctx.get_llvm_type(elt.custom.unwrap()); + let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); + let list; + let list_content; + + if is_range { + let iter_val = iter_val.into_pointer_value(); + let (start, end, step) = destructure_range(ctx, iter_val); + let diff = ctx.builder.build_int_sub(end, start, "diff"); + // add 1 to the length as the value is rounded to zero + // the length may be 1 more than the actual length if the division is exact, but the + // length is a upper bound only anyway so it does not matter. + let length = ctx.builder.build_int_signed_div(diff, step, "div"); + let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); + // in case length is non-positive + let is_valid = + ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, length, zero, "check"); + let normal = ctx.ctx.append_basic_block(current, "normal_list"); + let empty = ctx.ctx.append_basic_block(current, "empty_list"); + let list_init = ctx.ctx.append_basic_block(current, "list_init"); + ctx.builder.build_conditional_branch(is_valid, normal, empty); + // normal: allocate a list + ctx.builder.position_at_end(normal); + let list_a = allocate_list(ctx, elem_ty, length); + ctx.builder.build_unconditional_branch(list_init); + ctx.builder.position_at_end(empty); + let list_b = allocate_list(ctx, elem_ty, zero); + ctx.builder.build_unconditional_branch(list_init); + ctx.builder.position_at_end(list_init); + let phi = ctx.builder.build_phi(list_a.get_type(), "phi"); + phi.add_incoming(&[(&list_a, normal), (&list_b, empty)]); + list = phi.as_basic_value().into_pointer_value(); + list_content = ctx + .build_gep_and_load(list, &[zero, int32.const_int(1, false)]) + .into_pointer_value(); + + let i = generator.gen_store_target(ctx, target); + ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let sign = ctx.builder.build_int_compare( + inkwell::IntPredicate::SGT, + step, + int32.const_zero(), + "sign", + ); + // add and test + let tmp = ctx.builder.build_int_add( + ctx.builder.build_load(i, "i").into_int_value(), + step, + "start_loop", + ); + ctx.builder.build_store(i, tmp); + // // if step > 0, continue when i < end + let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1"); + // if step < 0, continue when i > end + let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2"); + let pos = ctx.builder.build_and(sign, cmp1, "pos"); + let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg"); + ctx.builder.build_conditional_branch( + ctx.builder.build_or(pos, neg, "or"), + body_bb, + cont_bb, + ); + ctx.builder.position_at_end(body_bb); + } else { + let length = ctx + .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero]) + .into_int_value(); + list = allocate_list(ctx, elem_ty, length); + list_content = ctx + .build_gep_and_load(list, &[zero, int32.const_int(1, false)]) + .into_pointer_value(); + let counter = generator.gen_var_alloc(ctx, ctx.primitives.int32); + // counter = -1 + ctx.builder.build_store(counter, ctx.ctx.i32_type().const_int(u64::max_value(), true)); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let tmp = ctx.builder.build_load(counter, "i").into_int_value(); + let tmp = ctx.builder.build_int_add(tmp, int32.const_int(1, false), "inc"); + ctx.builder.build_store(counter, tmp); + let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, length, "cmp"); + ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb); + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero, int32.const_int(1, false)], + ) + .into_pointer_value(); + let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); + generator.gen_assign(ctx, target, val); + } + for cond in ifs.iter() { + let result = generator.gen_expr(ctx, cond).unwrap().into_int_value(); + let succ = ctx.ctx.append_basic_block(current, "then"); + ctx.builder.build_conditional_branch(result, succ, test_bb); + ctx.builder.position_at_end(succ); + } + let elem = generator.gen_expr(ctx, elt).unwrap(); + let i = ctx.builder.build_load(index, "i").into_int_value(); + let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }; + ctx.builder.build_store(elem_ptr, elem); + ctx.builder + .build_store(index, ctx.builder.build_int_add(i, int32.const_int(1, false), "inc")); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(cont_bb); + let len_ptr = unsafe { ctx.builder.build_gep(list, &[zero, zero], "length") }; + ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); + list.into() + } else { + unreachable!() + } +} + pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -731,6 +866,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( ctx.build_gep_and_load(v, &[int32.const_zero(), index]) } } + ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr), _ => unimplemented!(), }) }