forked from M-Labs/nac3
1
0
Fork 0

core: Add ListValue and helper functions

This commit is contained in:
David Mak 2024-01-23 17:21:24 +08:00
parent f1581299fc
commit 5ee08b585f
6 changed files with 308 additions and 142 deletions

View File

@ -0,0 +1,225 @@
use inkwell::{
IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use crate::codegen::{CodeGenContext, CodeGenerator};
#[cfg(not(debug_assertions))]
pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {}
#[cfg(debug_assertions)]
pub fn assert_is_list<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) {
if let Err(msg) = ListValue::is_instance(value, llvm_usize) {
panic!("{msg}")
}
}
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>);
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
panic!("Expected struct type for `list` type, got {llvm_list_ty}")
};
if llvm_list_ty.count_fields() != 2 {
return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields()))
}
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
let Ok(_) = PointerType::try_from(list_size_ty) else {
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"))
};
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"))
};
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!("Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()))
}
Ok(())
}
/// Creates an [ListValue] from a [PointerValue].
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_list(ptr, llvm_usize);
ListValue(ptr, name)
}
/// Returns the underlying [PointerValue] pointing to the `list` instance.
pub fn get_ptr(&self) -> PointerValue<'ctx> {
self.0
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
fn get_data_pptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.get_ptr(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
}
}
/// Returns the pointer to the field storing the size of this `list`.
fn get_size_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
}
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.get_data_pptr(ctx), data);
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
///
/// If `size` is [None], the size stored in the field of this instance is used instead.
pub fn create_data(
&self,
ctx: &CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: Option<IntValue<'ctx>>,
) {
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, ""));
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn get_data(&self) -> ListDataProxy<'ctx> {
ListDataProxy(self.clone())
}
/// Stores the `size` of this `list` into this instance.
pub fn store_size(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &dyn CodeGenerator,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
let psize = self.get_size_ptr(ctx);
ctx.builder.build_store(psize, size);
}
/// Returns the size of this `list` as a value.
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let psize = self.get_size_ptr(ctx);
let var_name = name
.map(|v| v.to_string())
.or_else(|| self.1.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder.build_load(psize, var_name.as_str()).into_int_value()
}
}
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
#[derive(Copy, Clone)]
pub struct ListDataProxy<'ctx>(ListValue<'ctx>);
impl<'ctx> ListDataProxy<'ctx> {
/// Returns the single-indirection pointer to the array.
pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.get_data_pptr(ctx), var_name.as_str()).into_pointer_value()
}
pub unsafe fn ptr_offset_unchecked(
&self,
ctx: &CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.get_ptr(ctx),
&[idx],
var_name.as_str(),
)
}
/// Returns the pointer to the data at the `idx`-th index.
pub fn ptr_offset(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let in_range = ctx.builder.build_int_compare(
IntPredicate::ULT,
idx,
self.0.load_size(ctx, None),
""
);
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe {
self.ptr_offset_unchecked(ctx, idx, name)
}
}
pub unsafe fn get_unchecked(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_unchecked(ctx, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the `idx`-th flattened index.
pub fn get(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset(ctx, generator, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
}

View File

@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{ use crate::{
codegen::{ codegen::{
classes::ListValue,
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, gen_in_range_check,
get_llvm_type, get_llvm_type,
@ -896,43 +897,26 @@ pub fn allocate_list<'ctx, G: CodeGenerator>(
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
length: IntValue<'ctx>, length: IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> ListValue<'ctx> {
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let i32_t = ctx.ctx.i32_type();
// List structure; type { ty*, size_t } // List structure; type { ty*, size_t }
let arr_ty = ctx.ctx let arr_ty = ctx.ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
let zero = ctx.ctx.i32_type().const_zero();
let arr_str_ptr = ctx.builder.build_alloca( let arr_str_ptr = ctx.builder.build_alloca(
arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str() arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str()
); );
let list = ListValue::from_ptr_val(arr_str_ptr, size_t, Some("list"));
unsafe { let length = ctx.builder.build_int_z_extend(
// Pointer to the `length` element of the list structure length,
let len_ptr = ctx.builder.build_in_bounds_gep( size_t,
arr_str_ptr, ""
&[zero, i32_t.const_int(1, false)], );
"" list.store_size(ctx, generator, length);
); list.create_data(ctx, ty, None);
let length = ctx.builder.build_int_z_extend(
length,
size_t,
""
);
ctx.builder.build_store(len_ptr, length);
// Pointer to the `data` element of the list structure list
let arr_ptr = ctx.builder.build_array_alloca(ty, length, "");
let ptr_to_arr = ctx.builder.build_in_bounds_gep(
arr_str_ptr,
&[zero, i32_t.const_zero()],
""
);
ctx.builder.build_store(ptr_to_arr, arr_ptr);
}
arr_str_ptr
} }
/// Generates LLVM IR for a [list comprehension expression][expr]. /// Generates LLVM IR for a [list comprehension expression][expr].
@ -1006,8 +990,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
list_alloc_size.into_int_value(), list_alloc_size.into_int_value(),
Some("listcomp.addr") Some("listcomp.addr")
); );
list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr")) list_content = list.get_data().get_ptr(ctx);
.into_pointer_value();
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init"));
@ -1042,8 +1025,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
) )
.into_int_value(); .into_int_value();
list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp")); list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp"));
list_content = list_content = list.get_data().get_ptr(ctx);
ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value();
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1 // counter = -1
ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true));
@ -1065,12 +1047,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`
let emit_cont_bb = |ctx: &CodeGenContext| { let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| {
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
let len_ptr = unsafe { list.store_size(ctx, generator, ctx.builder.build_load(index, "index").into_int_value());
ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length")
};
ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index"));
}; };
for cond in ifs { for cond in ifs {
@ -1079,7 +1058,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} else { } else {
// Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the
// no element matches the predicate // no element matches the predicate
emit_cont_bb(ctx); emit_cont_bb(ctx, generator, list);
return Ok(None) return Ok(None)
}; };
@ -1092,7 +1071,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
let Some(elem) = generator.gen_expr(ctx, elt)? else { let Some(elem) = generator.gen_expr(ctx, elt)? else {
// Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents
emit_cont_bb(ctx); emit_cont_bb(ctx, generator, list);
return Ok(None) return Ok(None)
}; };
@ -1104,9 +1083,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc")); .build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc"));
ctx.builder.build_unconditional_branch(test_bb); ctx.builder.build_unconditional_branch(test_bb);
emit_cont_bb(ctx); emit_cont_bb(ctx, generator, list);
Ok(Some(list.into())) Ok(Some(list.get_ptr().into()))
} }
/// Generates LLVM IR for a [binary operator expression][expr]. /// Generates LLVM IR for a [binary operator expression][expr].
@ -1226,6 +1205,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
ctx.current_loc = expr.location; ctx.current_loc = expr.location;
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let usize = generator.get_size_type(ctx.ctx);
let zero = int32.const_int(0, false); let zero = int32.const_int(0, false);
let loc = ctx.debug_info.0.create_debug_location( let loc = ctx.debug_info.0.create_debug_location(
@ -1296,19 +1276,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}; };
let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false);
let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list"));
let arr_ptr = ctx.build_gep_and_load(arr_str_ptr, &[zero, zero], Some("list.ptr.addr")) let arr_ptr = arr_str_ptr.get_data();
.into_pointer_value(); for (i, v) in elements.iter().enumerate() {
unsafe { let elem_ptr = arr_ptr
for (i, v) in elements.iter().enumerate() { .ptr_offset(ctx, generator, usize.const_int(i as u64, false), Some("elem_ptr"));
let elem_ptr = ctx.builder.build_gep( ctx.builder.build_store(elem_ptr, *v);
arr_ptr,
&[int32.const_int(i as u64, false)],
"elem_ptr",
);
ctx.builder.build_store(elem_ptr, *v);
}
} }
arr_str_ptr.into() arr_str_ptr.get_ptr().into()
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let elements_val = elts let elements_val = elts
@ -1758,9 +1732,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
return Ok(None) return Ok(None)
}; };
let v = ListValue::from_ptr_val(v, usize, Some("arr"));
let ty = ctx.get_llvm_type(generator, *ty); let ty = ctx.get_llvm_type(generator, *ty);
let arr_ptr = ctx.build_gep_and_load(v, &[zero, zero], Some("arr.addr"))
.into_pointer_value();
if let ExprKind::Slice { lower, upper, step } = &slice.node { if let ExprKind::Slice { lower, upper, step } = &slice.node {
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let Some((start, end, step)) = let Some((start, end, step)) =
@ -1800,11 +1773,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v, v,
(start, end, step), (start, end, step),
); );
res_array_ret.into() res_array_ret.get_ptr().into()
} else { } else {
let len = ctx let len = v.load_size(ctx, Some("len"));
.build_gep_and_load(v, &[zero, int32.const_int(1, false)], Some("len"))
.into_int_value();
let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else { } else {
@ -1843,7 +1814,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
[Some(raw_index), Some(len), None], [Some(raw_index), Some(len), None],
expr.location, expr.location,
); );
ctx.build_gep_and_load(arr_ptr, &[index], None).into() v.get_data().get(ctx, generator, index, None).into()
} }
} }
TypeEnum::TNDArray { .. } => { TypeEnum::TNDArray { .. } => {

View File

@ -1,6 +1,11 @@
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator}; use super::{
classes::ListValue,
assert_is_ndarray,
CodeGenContext,
CodeGenerator,
};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -158,12 +163,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
step: &Option<Box<Expr<Option<Type>>>>, step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G, generator: &mut G,
list: PointerValue<'ctx>, list: ListValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> { ) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value(); let length = list.load_size(ctx, Some("length"));
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32"); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32");
Ok(Some(match (start, end, step) { Ok(Some(match (start, end, step) {
(s, e, None) => ( (s, e, None) => (
@ -295,9 +300,9 @@ pub fn list_slice_assignment<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
dest_arr: PointerValue<'ctx>, dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: PointerValue<'ctx>, src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) { ) {
let size_ty = generator.get_size_type(ctx.ctx); let size_ty = generator.get_size_type(ctx.ctx);
@ -326,21 +331,21 @@ pub fn list_slice_assignment<'ctx>(
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr")); let dest_arr_ptr = dest_arr.get_data().get_ptr(ctx);
let dest_arr_ptr = ctx.builder.build_pointer_cast( let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr.into_pointer_value(), dest_arr_ptr,
elem_ptr_type, elem_ptr_type,
"dest_arr_ptr_cast", "dest_arr_ptr_cast",
); );
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value(); let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32"); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32");
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr")); let src_arr_ptr = src_arr.get_data().get_ptr(ctx);
let src_arr_ptr = ctx.builder.build_pointer_cast( let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr.into_pointer_value(), src_arr_ptr,
elem_ptr_type, elem_ptr_type,
"src_arr_ptr_cast", "src_arr_ptr_cast",
); );
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value(); let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32"); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
// index in bound and positive should be done // index in bound and positive should be done
@ -443,9 +448,8 @@ pub fn list_slice_assignment<'ctx>(
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb);
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") };
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len"); let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len");
ctx.builder.build_store(dest_len_ptr, new_len); dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }
@ -604,11 +608,8 @@ pub fn call_ndarray_init_dims<'ctx>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: PointerValue<'ctx>, ndarray: PointerValue<'ctx>,
shape: PointerValue<'ctx>, shape: ListValue<'ctx>,
) { ) {
assert_is_ndarray(ndarray);
assert_is_list(shape);
let llvm_void = ctx.ctx.void_type(); let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -616,6 +617,8 @@ pub fn call_ndarray_init_dims<'ctx>(
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_is_ndarray(ndarray);
let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_init_dims", 32 => "__nac3_ndarray_init_dims",
64 => "__nac3_ndarray_init_dims64", 64 => "__nac3_ndarray_init_dims64",
@ -639,11 +642,7 @@ pub fn call_ndarray_init_dims<'ctx>(
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None, None,
); );
let shape_data = ctx.build_gep_and_load( let shape_data = shape.get_data();
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None
);
let ndarray_num_dims = ctx.build_gep_and_load( let ndarray_num_dims = ctx.build_gep_and_load(
ndarray, ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()], &[llvm_i32.const_zero(), llvm_i32.const_zero()],
@ -654,7 +653,7 @@ pub fn call_ndarray_init_dims<'ctx>(
ndarray_init_dims_fn, ndarray_init_dims_fn,
&[ &[
ndarray_dims.into(), ndarray_dims.into(),
shape_data.into(), shape_data.get_ptr(ctx).into(),
ndarray_num_dims.into(), ndarray_num_dims.into(),
], ],
"", "",

View File

@ -37,6 +37,7 @@ use std::thread;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
use inkwell::types::AnyTypeEnum; use inkwell::types::AnyTypeEnum;
pub mod classes;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
mod generator; mod generator;
@ -999,22 +1000,6 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
} }
/// Checks whether the pointer `value` refers to a `list` in LLVM.
fn assert_is_list(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_shape_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
};
assert_eq!(llvm_shape_ty.count_fields(), 2);
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
}
value
}
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. /// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
fn assert_is_ndarray(value: PointerValue) -> PointerValue { fn assert_is_ndarray(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]

View File

@ -6,6 +6,7 @@ use super::{
}; };
use crate::{ use crate::{
codegen::{ codegen::{
classes::ListValue,
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
@ -92,6 +93,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
name: Option<&str>, name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String> { ) -> Result<Option<PointerValue<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
// very similar to gen_expr, but we don't do an extra load at the end // very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples // and we flatten nested tuples
Ok(Some(match &pattern.node { Ok(Some(match &pattern.node {
@ -132,16 +135,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
TypeEnum::TList { .. } => { TypeEnum::TList { .. } => {
let i32_type = ctx.ctx.i32_type();
let zero = i32_type.const_zero();
let v = generator let v = generator
.gen_expr(ctx, value)? .gen_expr(ctx, value)?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, value.custom.unwrap())? .to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value(); .into_pointer_value();
let len = ctx let v = ListValue::from_ptr_val(v, llvm_usize, None);
.build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) let len = v.load_size(ctx, Some("len"));
.into_int_value();
let raw_index = generator let raw_index = generator
.gen_expr(ctx, slice)? .gen_expr(ctx, slice)?
.unwrap() .unwrap()
@ -180,12 +180,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
[Some(raw_index), Some(len), None], [Some(raw_index), Some(len), None],
slice.location, slice.location,
); );
unsafe { v.get_data().ptr_offset(ctx, generator, index, name)
let arr_ptr = ctx
.build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr"))
.into_pointer_value();
ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or(""))
}
} }
TypeEnum::TNDArray { .. } => { TypeEnum::TNDArray { .. } => {
@ -206,6 +201,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
match &target.node { match &target.node {
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let BasicValueEnum::StructValue(v) = let BasicValueEnum::StructValue(v) =
@ -233,6 +230,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())? .to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value(); .into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) = let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls)? else { handle_slice_indices(lower, upper, step, ctx, generator, ls)? else {
return Ok(()) return Ok(())
@ -240,9 +238,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let value = value let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())? .to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value(); .into_pointer_value();
let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { let value = ListValue::from_ptr_val(value, llvm_usize, None);
unreachable!() let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
}; unreachable!()
};
let ty = ctx.get_llvm_type(generator, *ty); let ty = ctx.get_llvm_type(generator, *ty);
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {

View File

@ -3,6 +3,7 @@ use inkwell::values::{ArrayValue, IntValue};
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use crate::{ use crate::{
codegen::{ codegen::{
classes::ListValue,
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
irrt::{ irrt::{
@ -212,7 +213,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
shape: PointerValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
@ -239,29 +240,15 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
let i = ctx.builder let i = ctx.builder
.build_load(i_addr, "") .build_load(i_addr, "")
.into_int_value(); .into_int_value();
let shape_len = ctx.build_gep_and_load( let shape_len = shape.load_size(ctx, None);
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None,
).into_int_value();
Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, "")) Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, ""))
}, },
|generator, ctx, i_addr| { |generator, ctx, i_addr| {
let shape_elems = ctx.build_gep_and_load(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None
).into_pointer_value();
let i = ctx.builder let i = ctx.builder
.build_load(i_addr, "") .build_load(i_addr, "")
.into_int_value(); .into_int_value();
let shape_dim = ctx.build_gep_and_load( let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value();
shape_elems,
&[i],
None
).into_int_value();
let shape_dim_gez = ctx.builder.build_int_compare( let shape_dim_gez = ctx.builder.build_int_compare(
IntPredicate::SGE, IntPredicate::SGE,
@ -298,11 +285,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
None, None,
)?; )?;
let num_dims = ctx.build_gep_and_load( let num_dims = shape.load_size(ctx, None);
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None
).into_int_value();
let ndarray_num_dims = unsafe { let ndarray_num_dims = unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -507,7 +490,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
shape: PointerValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
let supported_types = [ let supported_types = [
ctx.primitives.int32, ctx.primitives.int32,
@ -543,7 +526,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
shape: PointerValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
let supported_types = [ let supported_types = [
ctx.primitives.int32, ctx.primitives.int32,
@ -579,7 +562,7 @@ fn call_ndarray_full_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
shape: PointerValue<'ctx>, shape: ListValue<'ctx>,
fill_value: BasicValueEnum<'ctx>, fill_value: BasicValueEnum<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
@ -725,6 +708,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone() let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?; .to_basic_value_enum(context, generator, shape_ty)?;
@ -733,7 +717,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
generator, generator,
context, context,
context.primitives.float, context.primitives.float,
shape_arg.into_pointer_value(), ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
) )
} }
@ -748,6 +732,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone() let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?; .to_basic_value_enum(context, generator, shape_ty)?;
@ -756,7 +741,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>(
generator, generator,
context, context,
context.primitives.float, context.primitives.float,
shape_arg.into_pointer_value(), ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
) )
} }
@ -771,6 +756,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone() let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?; .to_basic_value_enum(context, generator, shape_ty)?;
@ -779,7 +765,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>(
generator, generator,
context, context,
context.primitives.float, context.primitives.float,
shape_arg.into_pointer_value(), ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
) )
} }
@ -794,6 +780,7 @@ pub fn gen_ndarray_full<'ctx, 'a>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 2); assert_eq!(args.len(), 2);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone() let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?; .to_basic_value_enum(context, generator, shape_ty)?;
@ -805,7 +792,7 @@ pub fn gen_ndarray_full<'ctx, 'a>(
generator, generator,
context, context,
fill_value_ty, fill_value_ty,
shape_arg.into_pointer_value(), ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
fill_value_arg, fill_value_arg,
) )
} }