forked from M-Labs/nac3
core: Implement list::__add__
This commit is contained in:
parent
c85e412206
commit
66c205275f
|
@ -1199,82 +1199,163 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if is_aug_assign || op != Operator::Mult {
|
if is_aug_assign {
|
||||||
todo!("Only __mul__ is implemented for lists")
|
todo!("Augmented assignment operators not implemented for lists")
|
||||||
}
|
}
|
||||||
|
|
||||||
let (elem_ty, list_val, int_val) =
|
match op {
|
||||||
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
Operator::Add => {
|
||||||
let elem_ty =
|
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||||
|
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||||
|
|
||||||
|
let elem_ty1 =
|
||||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
|
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
|
||||||
*params.iter().next().unwrap().1
|
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
let elem_ty2 =
|
||||||
(elem_ty, left_val, right_val)
|
|
||||||
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
|
||||||
let elem_ty =
|
|
||||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
|
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
|
||||||
*params.iter().next().unwrap().1
|
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
|
||||||
|
|
||||||
(elem_ty, right_val, left_val)
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
|
||||||
} else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
|
||||||
let int_val =
|
|
||||||
ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap();
|
|
||||||
// [...] * (i where i < 0) => []
|
|
||||||
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
|
|
||||||
|
|
||||||
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
|
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
let new_list = allocate_list(
|
let size = ctx
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
Some(elem_llvm_ty),
|
|
||||||
ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(int_val, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let offset =
|
|
||||||
ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap();
|
|
||||||
let ptr =
|
|
||||||
unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) };
|
|
||||||
|
|
||||||
let memcpy_sz = ctx
|
|
||||||
.builder
|
.builder
|
||||||
.build_int_mul(
|
.build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "")
|
||||||
list_val.load_size(ctx, None),
|
|
||||||
elem_llvm_ty.size_of().unwrap(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
|
||||||
|
|
||||||
|
let lhs_len = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
|
||||||
|
.unwrap();
|
||||||
|
let rhs_len = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let list_ptr = new_list.data().base_ptr(ctx, generator);
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
ctx,
|
ctx,
|
||||||
ptr,
|
list_ptr,
|
||||||
list_val.data().base_ptr(ctx, generator),
|
lhs.data().base_ptr(ctx, generator),
|
||||||
memcpy_sz,
|
lhs_len,
|
||||||
ctx.ctx.bool_type().const_zero(),
|
ctx.ctx.bool_type().const_zero(),
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
let list_ptr = unsafe {
|
||||||
},
|
new_list.data().ptr_offset_unchecked(
|
||||||
llvm_usize.const_int(1, false),
|
ctx,
|
||||||
)?;
|
generator,
|
||||||
|
&lhs.load_size(ctx, None),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
list_ptr,
|
||||||
|
rhs.data().base_ptr(ctx, generator),
|
||||||
|
rhs_len,
|
||||||
|
ctx.ctx.bool_type().const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Some(new_list.as_base_value().into()))
|
Ok(Some(new_list.as_base_value().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::Mult => {
|
||||||
|
let (elem_ty, list_val, int_val) =
|
||||||
|
if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
||||||
|
let elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||||
|
&*ctx.unifier.get_ty_immutable(ty1)
|
||||||
|
{
|
||||||
|
*params.iter().next().unwrap().1
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
(elem_ty, left_val, right_val)
|
||||||
|
} else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) {
|
||||||
|
let elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||||
|
&*ctx.unifier.get_ty_immutable(ty2)
|
||||||
|
{
|
||||||
|
*params.iter().next().unwrap().1
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
(elem_ty, right_val, left_val)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let list_val =
|
||||||
|
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
let int_val = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
// [...] * (i where i < 0) => []
|
||||||
|
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
|
||||||
|
|
||||||
|
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let new_list = allocate_list(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some(elem_llvm_ty),
|
||||||
|
ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(int_val, false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
let offset = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(i, list_val.load_size(ctx, None), "")
|
||||||
|
.unwrap();
|
||||||
|
let ptr = unsafe {
|
||||||
|
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let memcpy_sz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(
|
||||||
|
list_val.load_size(ctx, None),
|
||||||
|
elem_llvm_ty.size_of().unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ptr,
|
||||||
|
list_val.data().base_ptr(ctx, generator),
|
||||||
|
memcpy_sz,
|
||||||
|
ctx.ctx.bool_type().const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(new_list.as_base_value().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => todo!("Operator not supported"),
|
||||||
|
}
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
|
|
|
@ -437,7 +437,7 @@ pub fn typeof_binop(
|
||||||
Ok(Some(match op {
|
Ok(Some(match op {
|
||||||
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
||||||
if is_left_list || is_right_list {
|
if is_left_list || is_right_list {
|
||||||
if op != Operator::Mult {
|
if ![Operator::Add, Operator::Mult].contains(&op) {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Binary operator {} not supported for list",
|
"Binary operator {} not supported for list",
|
||||||
binop_name(op)
|
binop_name(op)
|
||||||
|
@ -665,6 +665,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||||
|
|
||||||
/* list ======== */
|
/* list ======== */
|
||||||
|
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||||
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
|
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
|
||||||
|
|
||||||
/* ndarray ===== */
|
/* ndarray ===== */
|
||||||
|
|
Loading…
Reference in New Issue