diff --git a/nac3artiq/demo/type_conversion.py b/nac3artiq/demo/type_conversion.py index 39a0d766..3639c794 100644 --- a/nac3artiq/demo/type_conversion.py +++ b/nac3artiq/demo/type_conversion.py @@ -35,6 +35,9 @@ class Demo: for z in enumerated_tuple3: z[0] z[1] + for (h, u) in enumerated_tuple2: + h + u for p in enumerated_list: p[0] p[1] @@ -44,6 +47,9 @@ class Demo: for r in enumerated_list3: r[0] r[1] + for (m, n) in enumerated_list2: + m + n def run(self): self.test() diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 85ed24b3..752e14c2 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -189,6 +189,35 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } .unwrap() } + ExprKind::Tuple { elts, .. } => { + let elts: Vec> = elts + .iter() + .map(|e| { + generator.gen_store_target(ctx, e, name).and_then(|v| { + v.ok_or_else(|| "failed to generate store target".to_string()) + }) + }) + .collect::>()?; + let struct_ty = + ctx.ctx.struct_type(&elts.iter().map(|p| p.get_type().into()).collect_vec(), false); + let struct_ptr = gen_var(ctx, struct_ty, name); + for (i, elt) in elts.iter().enumerate() { + ctx.builder + .build_store( + unsafe { + ctx.builder.build_in_bounds_gep( + struct_ptr, + &[ctx.i32.const_zero(), ctx.i32.const_int(i as u64, false)], + "", + ) + } + .unwrap(), + *elt, + ) + .unwrap(); + } + struct_ptr + } _ => codegen_unreachable!(ctx), })) } @@ -705,18 +734,20 @@ fn build_tuple_elem_switch<'ctx>( /// * `element_ty` - The type of the iterable elements, if known. /// * `length` - The length of the iterable. /// * `start` - The starting index for enumeration. +/// * `target_expr` - The target expression to store the current element and/or the current index. /// * `target_i` - The pointer to store the current index. /// * `get_first_elem` - A closure that returns the first element of the iterable. /// * `get_next_elem` - A closure that returns the next element given the next index. /// * `body` - The body of the loop. /// * `orelse` - The `else` block of the loop. #[allow(clippy::too_many_arguments)] -fn gen_for_enumerate<'ctx, G, GetFirst, GetNext>( +fn gen_for_enumerate<'ctx, G, GetFirst, GetNext, U>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, element_ty: Option, length: IntValue<'ctx>, start: IntValue<'ctx>, + target_expr: &ExprKind, target_i: PointerValue<'ctx>, get_first_elem: GetFirst, get_next_elem: GetNext, @@ -758,9 +789,35 @@ where )) }, |generator, ctx, _, iv_pair| { - ctx.builder - .build_store(target_i, ctx.builder.build_load(iv_pair, "iv").unwrap()) - .unwrap(); + match target_expr { + ExprKind::Tuple { elts, .. } if elts.len() == 2 => { + let i = ctx.builder.build_struct_gep(iv_pair, 0, "i").unwrap(); + let i_val = ctx + .builder + .build_load(i, "i_val") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let ptr_1 = ctx.builder.build_struct_gep(target_i, 0, "tuple.0").unwrap(); + let addr_1 = + ctx.builder.build_load(ptr_1, "tuple.0.addr").unwrap().into_pointer_value(); + ctx.builder.build_store(addr_1, i_val).unwrap(); + let v = ctx.builder.build_struct_gep(iv_pair, 1, "v").unwrap(); + let v_val = ctx.builder.build_load(v, "").unwrap(); + let ptr_2 = ctx.builder.build_struct_gep(target_i, 1, "tuple.1").unwrap(); + let addr_2 = + ctx.builder.build_load(ptr_2, "tuple.1.addr").unwrap().into_pointer_value(); + ctx.builder.build_store(addr_2, v_val).unwrap(); + } + ExprKind::Name { .. } => { + ctx.builder + .build_store(target_i, ctx.builder.build_load(iv_pair, "iv").unwrap()) + .unwrap(); + } + _ => codegen_unreachable!( + ctx, + "expected target expression of for enumerate to be a Name or a Tuple" + ), + } generator.gen_block(ctx, body.iter())?; Ok(()) }, @@ -932,6 +989,7 @@ pub fn gen_for( element_ty, length, start, + &target.node, target_i, |ctx| { iterable.data(ctx).get_unchecked( @@ -964,6 +1022,7 @@ pub fn gen_for( element_ty, length, start, + &target.node, target_i, |ctx| iterable.extract(ctx, 0), |ctx, next_i| { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d7d14b9e..924ed618 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -276,7 +276,7 @@ impl Fold<()> for Inferencer<'_> { let target = self.fold_expr(*target)?; let iter = self.fold_expr(*iter)?; - let element_ty = self.check_iterable(&iter)?; + let element_ty = self.check_iterable(&iter, &target)?; self.unify(element_ty, target.custom.unwrap(), &target.location)?; let body = @@ -605,6 +605,7 @@ impl Inferencer<'_> { fn get_iterable_element_type( &mut self, iter: &Expr>, + target: &Expr>, ) -> Result, InferenceError> { let iter_ty = iter.custom.unwrap(); let ty_enum = self.unifier.get_ty(iter_ty); @@ -641,10 +642,30 @@ impl Inferencer<'_> { _ => primitives.int32, }; - let resulting_ty = self.unifier.add_ty(TypeEnum::TTuple { - ty: vec![primitives.int32, inner_elem_ty], - is_vararg_ctx: false, - }); + let resulting_ty = match &target.node { + ExprKind::Tuple { elts, .. } if elts.len() == 2 => { + let idx_target_ty = elts[0].custom.unwrap(); + let val_target_ty = elts[1].custom.unwrap(); + self.unify(primitives.int32, idx_target_ty, &target.location)?; + self.unify(inner_elem_ty, val_target_ty, &target.location)?; + self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![idx_target_ty, val_target_ty], + is_vararg_ctx: false, + }) + } + ExprKind::Name { .. } => self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![primitives.int32, inner_elem_ty], + is_vararg_ctx: false, + }), + _ => { + let iter_ty = iter.custom.unwrap(); + let iter_ty_str = self.unifier.stringify(iter_ty); + return report_error( + &format!("cannot unpack '{iter_ty_str}' object (expected 2 items)"), + iter.location, + ); + } + }; Ok(Some(resulting_ty)) } @@ -693,10 +714,14 @@ impl Inferencer<'_> { } /// Check if a type is iterable. Returns `Ok(())` if iterable, otherwise returns an error. - fn check_iterable(&mut self, iter: &Expr>) -> Result { + fn check_iterable( + &mut self, + iter: &Expr>, + target: &Expr>, + ) -> Result { let iter_ty = iter.custom.unwrap(); let location = iter.location; - if let Some(elem_ty) = self.get_iterable_element_type(iter)? { + if let Some(elem_ty) = self.get_iterable_element_type(iter, target)? { Ok(elem_ty) } else { let iter_ty_str = self.unifier.stringify(iter_ty); @@ -1148,7 +1173,7 @@ impl Inferencer<'_> { let iterable = self.fold_expr(args.remove(0))?; let iterable_ty = iterable.custom.unwrap(); - self.check_iterable(&iterable)?; + self.check_iterable(&iterable, &promoted_func)?; args_new.push(iterable); if !args.is_empty() { diff --git a/nac3standalone/demo/src/enumerate.py b/nac3standalone/demo/src/enumerate.py new file mode 100644 index 00000000..465e07e8 --- /dev/null +++ b/nac3standalone/demo/src/enumerate.py @@ -0,0 +1,120 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_str(s: str): + ... + +# ---- Tuple enumeration tests ---- +def test_tuple_basic(): + for a, b in enumerate((1, 2, 3, 4)): + output_int32(a) + output_int32(b) + +def test_tuple_with_start(): + for c, d in enumerate((10, 20, 30), 5): + output_int32(c) + output_int32(d) + +def test_tuple_single_element(): + for e, f in enumerate((42,), 7): + output_int32(e) + output_int32(f) + + +# ---- List enumeration tests ---- +def test_list_basic(): + for g, h in enumerate([5, 6, 7, 8]): + output_int32(g) + output_int32(h) + +def test_list_with_start(): + for i, j in enumerate([100, 200, 300], 3): + output_int32(i) + output_int32(j) + +def test_list_single_element(): + for k, l in enumerate([99], 2): + output_int32(k) + output_int32(l) + + +# ---- Empty containers ---- +def test_empty_tuple(): + for m, n in enumerate((), 1): + output_int32(m) + +def test_empty_list(): + for o, p in enumerate([], 4): + output_int32(o) + + +# ---- Nested tuple elements in list ---- +def test_list_of_tuples(): + for q in enumerate([(2, 3), (6, 7), (8, 9)], 1): + output_int32(q[0]) + output_int32(q[1][0]) + output_int32(q[1][1]) + + +# ---- Iterating over previously defined variables ---- +def test_variable_tuple(): + my_tuple = (11, 12, 13, 14) + for r, s in enumerate(my_tuple): + output_int32(r) + output_int32(s) + +def test_variable_list(): + my_list = [21, 22, 23, 24] + for t, u in enumerate(my_list, 10): + output_int32(t) + output_int32(u) + + +# ---- Tuple unpacking ---- +def test_unpack_list_of_tuples(): + pairs = [(31, 32), (33, 34), (35, 36)] + for v in enumerate(pairs): + output_int32(v[0]) + output_int32(v[1][0]) + output_int32(v[1][1]) + +# ---- Enumerate with different types ---- +def test_different_types(): + mixed_list = [("a", 1), ("b", 2), ("c", 3)] + for w, x in enumerate(mixed_list): + output_int32(w) + output_str(x[0]) + output_int32(x[1]) + + +# ---- Main entry point ---- +def run() -> int32: + # simple tuple/list + test_tuple_basic() + test_tuple_with_start() + test_tuple_single_element() + + test_list_basic() + test_list_with_start() + test_list_single_element() + + # empty cases + test_empty_tuple() + test_empty_list() + + # tuples inside lists + test_list_of_tuples() + + # iteration over previously defined variables + test_variable_tuple() + test_variable_list() + + # unpacking tests + test_unpack_list_of_tuples() + + # different types + test_different_types() + + return 0