From 894083c6a398ab2c47929fc996c020fba4e018a2 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 1 Aug 2024 12:25:10 +0800 Subject: [PATCH] core/codegen: refactor gen_{for,comprehension} to match on iter type --- nac3core/src/codegen/expr.rs | 188 +++++++++------- nac3core/src/codegen/stmt.rs | 210 ++++++++++-------- nac3core/src/typecheck/type_inferencer/mod.rs | 28 +-- 3 files changed, 235 insertions(+), 191 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 494b2380..f727ad8a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -995,8 +995,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.position_at_end(init_bb); let Comprehension { target, iter, ifs, .. } = &generators[0]; + + let iter_ty = iter.custom.unwrap(); let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { - v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? + v.to_basic_value_enum(ctx, generator, iter_ty)? } else { for bb in [test_bb, body_bb, cont_bb] { ctx.builder.position_at_end(bb); @@ -1014,96 +1016,120 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.build_store(index, zero_size_t).unwrap(); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); - let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); let list; - if is_range { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); - let (start, stop, step) = destructure_range(ctx, iter_val); - let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); - // add 1 to the length as the value is rounded to zero - // the length may be 1 more than the actual length if the division is exact, but the - // length is a upper bound only anyway so it does not matter. - let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); - let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); - // in case length is non-positive - let is_valid = - ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); + match &*ctx.unifier.get_ty(iter_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => + { + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + let (start, stop, step) = destructure_range(ctx, iter_val); + let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); + // add 1 to the length as the value is rounded to zero + // the length may be 1 more than the actual length if the division is exact, but the + // length is a upper bound only anyway so it does not matter. + let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); + let length = + ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); + // in case length is non-positive + let is_valid = + ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); - let list_alloc_size = ctx - .builder - .build_select( - is_valid, - ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(), - zero_size_t, - "listcomp.alloc_size", - ) - .unwrap(); - list = allocate_list( - generator, - ctx, - Some(elem_ty), - list_alloc_size.into_int_value(), - Some("listcomp.addr"), - ); + let list_alloc_size = ctx + .builder + .build_select( + is_valid, + ctx.builder + .build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len") + .unwrap(), + zero_size_t, + "listcomp.alloc_size", + ) + .unwrap(); + list = allocate_list( + generator, + ctx, + Some(elem_ty), + list_alloc_size.into_int_value(), + Some("listcomp.addr"), + ); - let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); - ctx.builder - .build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap()) - .unwrap(); + let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); + ctx.builder + .build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap()) + .unwrap(); - ctx.builder - .build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb) - .unwrap(); + ctx.builder + .build_conditional_branch( + gen_in_range_check(ctx, start, stop, step), + test_bb, + cont_bb, + ) + .unwrap(); - ctx.builder.position_at_end(test_bb); - // add and test - let tmp = ctx - .builder - .build_int_add( - ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), - step, - "start_loop", - ) - .unwrap(); - ctx.builder.build_store(i, tmp).unwrap(); - ctx.builder - .build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb) - .unwrap(); + ctx.builder.position_at_end(test_bb); + // add and test + let tmp = ctx + .builder + .build_int_add( + ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), + step, + "start_loop", + ) + .unwrap(); + ctx.builder.build_store(i, tmp).unwrap(); + ctx.builder + .build_conditional_branch( + gen_in_range_check(ctx, tmp, stop, step), + body_bb, + cont_bb, + ) + .unwrap(); - ctx.builder.position_at_end(body_bb); - } else { - let length = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero_size_t, int32.const_int(1, false)], - Some("length"), - ) - .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + ctx.builder.position_at_end(body_bb); + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let length = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, int32.const_int(1, false)], + Some("length"), + ) + .into_int_value(); + list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); - let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; - // counter = -1 - ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap(); - ctx.builder.build_unconditional_branch(test_bb).unwrap(); + let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; + // counter = -1 + ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap(); + ctx.builder.build_unconditional_branch(test_bb).unwrap(); - ctx.builder.position_at_end(test_bb); - let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap(); - let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); - ctx.builder.build_store(counter, tmp).unwrap(); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); - ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap(); + ctx.builder.position_at_end(test_bb); + let tmp = + ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap(); + let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); + ctx.builder.build_store(counter, tmp).unwrap(); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); + ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap(); - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero_size_t, zero_32], - Some("arr.addr"), - ) - .into_pointer_value(); - let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); - generator.gen_assign(ctx, target, val.into())?; + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, zero_32], + Some("arr.addr"), + ) + .into_pointer_value(); + let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); + generator.gen_assign(ctx, target, val.into())?; + } + _ => { + panic!( + "unsupported list comprehension iterator type: {}", + ctx.unifier.stringify(iter_ty) + ); + } } // Emits the content of `cont_bb` diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index ee581959..1130cc18 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -315,9 +315,6 @@ pub fn gen_for( let orelse_bb = if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; - // Whether the iterable is a range() expression - let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); - // The BB containing the increment expression let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); // The BB containing the loop condition check @@ -326,113 +323,132 @@ pub fn gen_for( // store loop bb information and restore it later let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); + let iter_ty = iter.custom.unwrap(); let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { - v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? + v.to_basic_value_enum(ctx, generator, iter_ty)? } else { return Ok(()); }; - if is_iterable_range_expr { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); - // Internal variable for loop; Cannot be assigned - let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; - // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed - let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? - else { - unreachable!() - }; - let (start, stop, step) = destructure_range(ctx, iter_val); - - ctx.builder.build_store(i, start).unwrap(); - - // Check "If step is zero, ValueError is raised." - let rangenez = - ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap(); - ctx.make_assert( - generator, - rangenez, - "ValueError", - "range() arg 3 must not be zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + match &*ctx.unifier.get_ty(iter_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { - ctx.builder.position_at_end(cond_bb); - ctx.builder - .build_conditional_branch( - gen_in_range_check( - ctx, - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - stop, - step, - ), - body_bb, - orelse_bb, + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + // Internal variable for loop; Cannot be assigned + let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; + // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed + let Some(target_i) = + generator.gen_store_target(ctx, target, Some("for.target.addr"))? + else { + unreachable!() + }; + let (start, stop, step) = destructure_range(ctx, iter_val); + + ctx.builder.build_store(i, start).unwrap(); + + // Check "If step is zero, ValueError is raised." + let rangenez = ctx + .builder + .build_int_compare(IntPredicate::NE, step, int32.const_zero(), "") + .unwrap(); + ctx.make_assert( + generator, + rangenez, + "ValueError", + "range() arg 3 must not be zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + + { + ctx.builder.position_at_end(cond_bb); + ctx.builder + .build_conditional_branch( + gen_in_range_check( + ctx, + ctx.builder + .build_load(i, "") + .map(BasicValueEnum::into_int_value) + .unwrap(), + stop, + step, + ), + body_bb, + orelse_bb, + ) + .unwrap(); + } + + ctx.builder.position_at_end(incr_bb); + let next_i = ctx + .builder + .build_int_add( + ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), + step, + "inc", ) .unwrap(); + ctx.builder.build_store(i, next_i).unwrap(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + + ctx.builder.position_at_end(body_bb); + ctx.builder + .build_store( + target_i, + ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), + ) + .unwrap(); + generator.gen_block(ctx, body.iter())?; } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; + ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap(); + let len = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero, int32.const_int(1, false)], + Some("len"), + ) + .into_int_value(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - ctx.builder.position_at_end(incr_bb); - let next_i = ctx - .builder - .build_int_add( - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - step, - "inc", - ) - .unwrap(); - ctx.builder.build_store(i, next_i).unwrap(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + ctx.builder.position_at_end(cond_bb); + let index = ctx + .builder + .build_load(index_addr, "for.index") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); + ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); - ctx.builder.position_at_end(body_bb); - ctx.builder - .build_store( - target_i, - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - ) - .unwrap(); - generator.gen_block(ctx, body.iter())?; - } else { - let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; - ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap(); - let len = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero, int32.const_int(1, false)], - Some("len"), - ) - .into_int_value(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + ctx.builder.position_at_end(incr_bb); + let index = + ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); + ctx.builder.build_store(index_addr, inc).unwrap(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - ctx.builder.position_at_end(cond_bb); - let index = ctx - .builder - .build_load(index_addr, "for.index") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); - ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) + .into_pointer_value(); + let index = ctx + .builder + .build_load(index_addr, "for.index") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); - ctx.builder.position_at_end(incr_bb); - let index = - ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); - ctx.builder.build_store(index_addr, inc).unwrap(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) - .into_pointer_value(); - let index = ctx - .builder - .build_load(index_addr, "for.index") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); - generator.gen_assign(ctx, target, val.into())?; - generator.gen_block(ctx, body.iter())?; + generator.gen_assign(ctx, target, val.into())?; + generator.gen_block(ctx, body.iter())?; + } + _ => { + panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty)); + } } for (k, (_, _, counter)) in &var_assignment { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index f4d3a62e..88ae4d41 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -100,16 +100,18 @@ pub struct Inferencer<'a> { pub in_handler: bool, } +type InferenceError = HashSet; + struct NaiveFolder(); impl Fold<()> for NaiveFolder { type TargetU = Option; - type Error = HashSet; + type Error = InferenceError; fn map_user(&mut self, (): ()) -> Result { Ok(None) } } -fn report_error(msg: &str, location: Location) -> Result> { +fn report_error(msg: &str, location: Location) -> Result { Err(HashSet::from([format!("{msg} at {location}")])) } @@ -117,13 +119,13 @@ fn report_type_error( kind: TypeErrorKind, loc: Option, unifier: &Unifier, -) -> Result> { +) -> Result { Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()])) } impl<'a> Fold<()> for Inferencer<'a> { type TargetU = Option; - type Error = HashSet; + type Error = InferenceError; fn map_user(&mut self, (): ()) -> Result { Ok(None) @@ -612,22 +614,22 @@ impl<'a> Fold<()> for Inferencer<'a> { } } -type InferenceResult = Result>; +type InferenceResult = Result; impl<'a> Inferencer<'a> { /// Constrain a <: b /// Currently implemented as unification - fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet> { + fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { self.unify(a, b, location) } - fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet> { + fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { self.unifier.unify(a, b).map_err(|e| { HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()]) }) } - fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet> { + fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), InferenceError> { match &pattern.node { ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { @@ -716,7 +718,7 @@ impl<'a> Inferencer<'a> { location: Location, args: Arguments, body: ast::Expr<()>, - ) -> Result>, HashSet> { + ) -> Result>, InferenceError> { if !args.posonlyargs.is_empty() || args.vararg.is_some() || !args.kwonlyargs.is_empty() @@ -787,7 +789,7 @@ impl<'a> Inferencer<'a> { location: Location, elt: ast::Expr<()>, mut generators: Vec, - ) -> Result>, HashSet> { + ) -> Result>, InferenceError> { if generators.len() != 1 { return report_error( "Only 1 generator statement for list comprehension is supported", @@ -893,7 +895,7 @@ impl<'a> Inferencer<'a> { id: StrRef, arg_index: usize, shape_expr: Located, - ) -> Result<(u64, ast::Expr>), HashSet> { + ) -> Result<(u64, ast::Expr>), InferenceError> { /* ### Further explanation @@ -1030,7 +1032,7 @@ impl<'a> Inferencer<'a> { func: &ast::Expr<()>, args: &mut Vec>, keywords: &[Located], - ) -> Result>>, HashSet> { + ) -> Result>>, InferenceError> { let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { return Ok(None); }; @@ -1588,7 +1590,7 @@ impl<'a> Inferencer<'a> { func: ast::Expr<()>, mut args: Vec>, keywords: Vec>, - ) -> Result>, HashSet> { + ) -> Result>, InferenceError> { if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {