forked from M-Labs/nac3
1
0
Fork 0

core: Implement list::__add__

This commit is contained in:
David Mak 2024-07-04 15:50:19 +08:00
parent c85e412206
commit 66c205275f
2 changed files with 137 additions and 55 deletions

View File

@ -1199,14 +1199,84 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if is_aug_assign || op != Operator::Mult { if is_aug_assign {
todo!("Only __mul__ is implemented for lists") todo!("Augmented assignment operators not implemented for lists")
} }
match op {
Operator::Add => {
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
let elem_ty1 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
};
let elem_ty2 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
};
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
let size = ctx
.builder
.build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "")
.unwrap();
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
let lhs_len = ctx
.builder
.build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
.unwrap();
let rhs_len = ctx
.builder
.build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
.unwrap();
let list_ptr = new_list.data().base_ptr(ctx, generator);
call_memcpy_generic(
ctx,
list_ptr,
lhs.data().base_ptr(ctx, generator),
lhs_len,
ctx.ctx.bool_type().const_zero(),
);
let list_ptr = unsafe {
new_list.data().ptr_offset_unchecked(
ctx,
generator,
&lhs.load_size(ctx, None),
None,
)
};
call_memcpy_generic(
ctx,
list_ptr,
rhs.data().base_ptr(ctx, generator),
rhs_len,
ctx.ctx.bool_type().const_zero(),
);
Ok(Some(new_list.as_base_value().into()))
}
Operator::Mult => {
let (elem_ty, list_val, int_val) = let (elem_ty, list_val, int_val) =
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty = let elem_ty = if let TypeEnum::TObj { params, .. } =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) { &*ctx.unifier.get_ty_immutable(ty1)
{
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() unreachable!()
@ -1214,8 +1284,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
(elem_ty, left_val, right_val) (elem_ty, left_val, right_val)
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { } else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty = let elem_ty = if let TypeEnum::TObj { params, .. } =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) { &*ctx.unifier.get_ty_immutable(ty2)
{
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() unreachable!()
@ -1225,9 +1296,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else { } else {
unreachable!() unreachable!()
}; };
let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); let list_val =
let int_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap(); let int_val = ctx
.builder
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
.unwrap();
// [...] * (i where i < 0) => [] // [...] * (i where i < 0) => []
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
@ -1247,10 +1321,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
llvm_usize.const_zero(), llvm_usize.const_zero(),
(int_val, false), (int_val, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
let offset = let offset = ctx
ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap(); .builder
let ptr = .build_int_mul(i, list_val.load_size(ctx, None), "")
unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) }; .unwrap();
let ptr = unsafe {
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
};
let memcpy_sz = ctx let memcpy_sz = ctx
.builder .builder
@ -1275,6 +1352,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
)?; )?;
Ok(Some(new_list.as_base_value().into())) Ok(Some(new_list.as_base_value().into()))
}
_ => todo!("Operator not supported"),
}
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {

View File

@ -437,7 +437,7 @@ pub fn typeof_binop(
Ok(Some(match op { Ok(Some(match op {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list { if is_left_list || is_right_list {
if op != Operator::Mult { if ![Operator::Add, Operator::Mult].contains(&op) {
return Err(format!( return Err(format!(
"Binary operator {} not supported for list", "Binary operator {} not supported for list",
binop_name(op) binop_name(op)
@ -665,6 +665,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* list ======== */ /* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
/* ndarray ===== */ /* ndarray ===== */