diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c8276d49..435243a8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1199,82 +1199,163 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( { let llvm_usize = generator.get_size_type(ctx.ctx); - if is_aug_assign || op != Operator::Mult { - todo!("Only __mul__ is implemented for lists") + if is_aug_assign { + todo!("Augmented assignment operators not implemented for lists") } - let (elem_ty, list_val, int_val) = - if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { - let elem_ty = + match op { + Operator::Add => { + debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id())); + debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id())); + + let elem_ty1 = if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) { - *params.iter().next().unwrap().1 + ctx.unifier.get_representative(*params.iter().next().unwrap().1) } else { unreachable!() }; - - (elem_ty, left_val, right_val) - } else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { - let elem_ty = + let elem_ty2 = if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) { - *params.iter().next().unwrap().1 + ctx.unifier.get_representative(*params.iter().next().unwrap().1) } else { unreachable!() }; + debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2)); - (elem_ty, right_val, left_val) - } else { - unreachable!() - }; - let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); - let int_val = - ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap(); - // [...] * (i where i < 0) => [] - let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1); - let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); + let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); + let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); - let new_list = allocate_list( - generator, - ctx, - Some(elem_llvm_ty), - ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), - None, - ); - - gen_for_callback_incrementing( - generator, - ctx, - llvm_usize.const_zero(), - (int_val, false), - |generator, ctx, _, i| { - let offset = - ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap(); - let ptr = - unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) }; - - let memcpy_sz = ctx + let size = ctx .builder - .build_int_mul( - list_val.load_size(ctx, None), - elem_llvm_ty.size_of().unwrap(), - "", - ) + .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); + let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); + + let lhs_len = ctx + .builder + .build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "") + .unwrap(); + let rhs_len = ctx + .builder + .build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "") + .unwrap(); + + let list_ptr = new_list.data().base_ptr(ctx, generator); call_memcpy_generic( ctx, - ptr, - list_val.data().base_ptr(ctx, generator), - memcpy_sz, + list_ptr, + lhs.data().base_ptr(ctx, generator), + lhs_len, ctx.ctx.bool_type().const_zero(), ); - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + let list_ptr = unsafe { + new_list.data().ptr_offset_unchecked( + ctx, + generator, + &lhs.load_size(ctx, None), + None, + ) + }; + call_memcpy_generic( + ctx, + list_ptr, + rhs.data().base_ptr(ctx, generator), + rhs_len, + ctx.ctx.bool_type().const_zero(), + ); - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_base_value().into())) + } + + Operator::Mult => { + let (elem_ty, list_val, int_val) = + if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { + let elem_ty = if let TypeEnum::TObj { params, .. } = + &*ctx.unifier.get_ty_immutable(ty1) + { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + + (elem_ty, left_val, right_val) + } else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { + let elem_ty = if let TypeEnum::TObj { params, .. } = + &*ctx.unifier.get_ty_immutable(ty2) + { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + + (elem_ty, right_val, left_val) + } else { + unreachable!() + }; + let list_val = + ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); + let int_val = ctx + .builder + .build_int_s_extend(int_val.into_int_value(), llvm_usize, "") + .unwrap(); + // [...] * (i where i < 0) => [] + let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); + + let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); + + let new_list = allocate_list( + generator, + ctx, + Some(elem_llvm_ty), + ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), + None, + ); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (int_val, false), + |generator, ctx, _, i| { + let offset = ctx + .builder + .build_int_mul(i, list_val.load_size(ctx, None), "") + .unwrap(); + let ptr = unsafe { + new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) + }; + + let memcpy_sz = ctx + .builder + .build_int_mul( + list_val.load_size(ctx, None), + elem_llvm_ty.size_of().unwrap(), + "", + ) + .unwrap(); + + call_memcpy_generic( + ctx, + ptr, + list_val.data().base_ptr(ctx, generator), + memcpy_sz, + ctx.ctx.bool_type().const_zero(), + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(Some(new_list.as_base_value().into())) + } + + _ => todo!("Operator not supported"), + } } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index f2464c03..cbf87077 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -437,7 +437,7 @@ pub fn typeof_binop( Ok(Some(match op { Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { if is_left_list || is_right_list { - if op != Operator::Mult { + if ![Operator::Add, Operator::Mult].contains(&op) { return Err(format!( "Binary operator {} not supported for list", binop_name(op) @@ -665,6 +665,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* list ======== */ + impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); /* ndarray ===== */