forked from M-Labs/nac3
core: Implement list::__add__
This commit is contained in:
parent
c85e412206
commit
66c205275f
|
@ -1199,14 +1199,84 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if is_aug_assign || op != Operator::Mult {
|
||||
todo!("Only __mul__ is implemented for lists")
|
||||
if is_aug_assign {
|
||||
todo!("Augmented assignment operators not implemented for lists")
|
||||
}
|
||||
|
||||
match op {
|
||||
Operator::Add => {
|
||||
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||
|
||||
let elem_ty1 =
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
|
||||
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let elem_ty2 =
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
|
||||
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
|
||||
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
|
||||
|
||||
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
|
||||
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
let size = ctx
|
||||
.builder
|
||||
.build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "")
|
||||
.unwrap();
|
||||
|
||||
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
|
||||
|
||||
let lhs_len = ctx
|
||||
.builder
|
||||
.build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
|
||||
.unwrap();
|
||||
let rhs_len = ctx
|
||||
.builder
|
||||
.build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
|
||||
.unwrap();
|
||||
|
||||
let list_ptr = new_list.data().base_ptr(ctx, generator);
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
list_ptr,
|
||||
lhs.data().base_ptr(ctx, generator),
|
||||
lhs_len,
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
);
|
||||
|
||||
let list_ptr = unsafe {
|
||||
new_list.data().ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&lhs.load_size(ctx, None),
|
||||
None,
|
||||
)
|
||||
};
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
list_ptr,
|
||||
rhs.data().base_ptr(ctx, generator),
|
||||
rhs_len,
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
);
|
||||
|
||||
Ok(Some(new_list.as_base_value().into()))
|
||||
}
|
||||
|
||||
Operator::Mult => {
|
||||
let (elem_ty, list_val, int_val) =
|
||||
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
||||
let elem_ty =
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
|
||||
let elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(ty1)
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
unreachable!()
|
||||
|
@ -1214,8 +1284,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
|
||||
(elem_ty, left_val, right_val)
|
||||
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
||||
let elem_ty =
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
|
||||
let elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(ty2)
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
unreachable!()
|
||||
|
@ -1225,9 +1296,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
||||
let int_val =
|
||||
ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap();
|
||||
let list_val =
|
||||
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
||||
let int_val = ctx
|
||||
.builder
|
||||
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
// [...] * (i where i < 0) => []
|
||||
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
|
||||
|
||||
|
@ -1247,10 +1321,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
llvm_usize.const_zero(),
|
||||
(int_val, false),
|
||||
|generator, ctx, _, i| {
|
||||
let offset =
|
||||
ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap();
|
||||
let ptr =
|
||||
unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) };
|
||||
let offset = ctx
|
||||
.builder
|
||||
.build_int_mul(i, list_val.load_size(ctx, None), "")
|
||||
.unwrap();
|
||||
let ptr = unsafe {
|
||||
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
|
||||
};
|
||||
|
||||
let memcpy_sz = ctx
|
||||
.builder
|
||||
|
@ -1275,6 +1352,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
)?;
|
||||
|
||||
Ok(Some(new_list.as_base_value().into()))
|
||||
}
|
||||
|
||||
_ => todo!("Operator not supported"),
|
||||
}
|
||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
|
|
|
@ -437,7 +437,7 @@ pub fn typeof_binop(
|
|||
Ok(Some(match op {
|
||||
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
||||
if is_left_list || is_right_list {
|
||||
if op != Operator::Mult {
|
||||
if ![Operator::Add, Operator::Mult].contains(&op) {
|
||||
return Err(format!(
|
||||
"Binary operator {} not supported for list",
|
||||
binop_name(op)
|
||||
|
@ -665,6 +665,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
|
||||
/* list ======== */
|
||||
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
|
||||
|
||||
/* ndarray ===== */
|
||||
|
|
Loading…
Reference in New Issue