1
0
forked from M-Labs/nac3

core: Implement list::__add__

This commit is contained in:
David Mak 2024-07-04 15:50:19 +08:00
parent c85e412206
commit 66c205275f
2 changed files with 137 additions and 55 deletions

View File

@ -1199,82 +1199,163 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{
let llvm_usize = generator.get_size_type(ctx.ctx);
if is_aug_assign || op != Operator::Mult {
todo!("Only __mul__ is implemented for lists")
if is_aug_assign {
todo!("Augmented assignment operators not implemented for lists")
}
let (elem_ty, list_val, int_val) =
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty =
match op {
Operator::Add => {
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
let elem_ty1 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
*params.iter().next().unwrap().1
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
};
(elem_ty, left_val, right_val)
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty =
let elem_ty2 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
*params.iter().next().unwrap().1
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
};
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
(elem_ty, right_val, left_val)
} else {
unreachable!()
};
let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
let int_val =
ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap();
// [...] * (i where i < 0) => []
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
let new_list = allocate_list(
generator,
ctx,
Some(elem_llvm_ty),
ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(),
None,
);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(int_val, false),
|generator, ctx, _, i| {
let offset =
ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap();
let ptr =
unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) };
let memcpy_sz = ctx
let size = ctx
.builder
.build_int_mul(
list_val.load_size(ctx, None),
elem_llvm_ty.size_of().unwrap(),
"",
)
.build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "")
.unwrap();
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
let lhs_len = ctx
.builder
.build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
.unwrap();
let rhs_len = ctx
.builder
.build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
.unwrap();
let list_ptr = new_list.data().base_ptr(ctx, generator);
call_memcpy_generic(
ctx,
ptr,
list_val.data().base_ptr(ctx, generator),
memcpy_sz,
list_ptr,
lhs.data().base_ptr(ctx, generator),
lhs_len,
ctx.ctx.bool_type().const_zero(),
);
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let list_ptr = unsafe {
new_list.data().ptr_offset_unchecked(
ctx,
generator,
&lhs.load_size(ctx, None),
None,
)
};
call_memcpy_generic(
ctx,
list_ptr,
rhs.data().base_ptr(ctx, generator),
rhs_len,
ctx.ctx.bool_type().const_zero(),
);
Ok(Some(new_list.as_base_value().into()))
Ok(Some(new_list.as_base_value().into()))
}
Operator::Mult => {
let (elem_ty, list_val, int_val) =
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(ty1)
{
*params.iter().next().unwrap().1
} else {
unreachable!()
};
(elem_ty, left_val, right_val)
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
let elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(ty2)
{
*params.iter().next().unwrap().1
} else {
unreachable!()
};
(elem_ty, right_val, left_val)
} else {
unreachable!()
};
let list_val =
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
let int_val = ctx
.builder
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
.unwrap();
// [...] * (i where i < 0) => []
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
let new_list = allocate_list(
generator,
ctx,
Some(elem_llvm_ty),
ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(),
None,
);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(int_val, false),
|generator, ctx, _, i| {
let offset = ctx
.builder
.build_int_mul(i, list_val.load_size(ctx, None), "")
.unwrap();
let ptr = unsafe {
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
};
let memcpy_sz = ctx
.builder
.build_int_mul(
list_val.load_size(ctx, None),
elem_llvm_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ptr,
list_val.data().base_ptr(ctx, generator),
memcpy_sz,
ctx.ctx.bool_type().const_zero(),
);
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(Some(new_list.as_base_value().into()))
}
_ => todo!("Operator not supported"),
}
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{

View File

@ -437,7 +437,7 @@ pub fn typeof_binop(
Ok(Some(match op {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list {
if op != Operator::Mult {
if ![Operator::Add, Operator::Mult].contains(&op) {
return Err(format!(
"Binary operator {} not supported for list",
binop_name(op)
@ -665,6 +665,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
/* ndarray ===== */