forked from M-Labs/nac3
1
0
Fork 0

core: Apply clippy suggestions

This commit is contained in:
David Mak 2024-02-20 18:07:55 +08:00
parent 8492503af2
commit 49de81ef1e
15 changed files with 206 additions and 167 deletions

View File

@ -65,8 +65,8 @@ enum Isa {
impl Isa { impl Isa {
/// Returns the number of bits in `size_t` for the [`Isa`]. /// Returns the number of bits in `size_t` for the [`Isa`].
fn get_size_type(&self) -> u32 { fn get_size_type(self) -> u32 {
if self == &Isa::Host { if self == Isa::Host {
64u32 64u32
} else { } else {
32u32 32u32

View File

@ -662,10 +662,10 @@ impl InnerResolver {
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty { match actual_ty {
Ok(t) => match unifier.unify(*ty, t) { Ok(t) => match unifier.unify(*ty, t) {
Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))), Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
Err(e) => Ok(Err(format!( Err(e) => Ok(Err(format!(
"type error ({}) for the ndarray", "type error ({}) for the ndarray",
e.to_display(unifier).to_string() e.to_display(unifier),
))), ))),
}, },
Err(e) => Ok(Err(e)), Err(e) => Ok(Err(e)),

View File

@ -58,13 +58,15 @@ impl<'ctx> ListValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [ListValue] from a [PointerValue]. /// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_list(ptr, llvm_usize); assert_is_list(ptr, llvm_usize);
ListValue(ptr, name) ListValue(ptr, name)
} }
/// Returns the underlying [PointerValue] pointing to the `list` instance. /// Returns the underlying [`PointerValue`] pointing to the `list` instance.
#[must_use]
pub fn get_ptr(&self) -> PointerValue<'ctx> { pub fn get_ptr(&self) -> PointerValue<'ctx> {
self.0 self.0
} }
@ -119,8 +121,9 @@ impl<'ctx> ListValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
#[must_use]
pub fn get_data(&self) -> ListDataProxy<'ctx> { pub fn get_data(&self) -> ListDataProxy<'ctx> {
ListDataProxy(self.clone()) ListDataProxy(*self)
} }
/// Stores the `size` of this `list` into this instance. /// Stores the `size` of this `list` into this instance.
@ -140,7 +143,7 @@ impl<'ctx> ListValue<'ctx> {
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let psize = self.get_size_ptr(ctx); let psize = self.get_size_ptr(ctx);
let var_name = name let var_name = name
.map(|v| v.to_string()) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.size"))) .or_else(|| self.1.map(|v| format!("{v}.size")))
.unwrap_or_default(); .unwrap_or_default();
@ -164,6 +167,9 @@ impl<'ctx> ListDataProxy<'ctx> {
.unwrap() .unwrap()
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn ptr_offset_unchecked( pub unsafe fn ptr_offset_unchecked(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -211,6 +217,9 @@ impl<'ctx> ListDataProxy<'ctx> {
} }
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn get_unchecked( pub unsafe fn get_unchecked(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -271,13 +280,15 @@ impl<'ctx> RangeValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [RangeValue] from a [PointerValue]. /// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_range(ptr); assert_is_range(ptr);
RangeValue(ptr, name) RangeValue(ptr, name)
} }
/// Returns the underlying [PointerValue] pointing to the `range` instance. /// Returns the underlying [`PointerValue`] pointing to the `range` instance.
#[must_use]
pub fn get_ptr(&self) -> PointerValue<'ctx> { pub fn get_ptr(&self) -> PointerValue<'ctx> {
self.0 self.0
} }
@ -337,7 +348,7 @@ impl<'ctx> RangeValue<'ctx> {
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstart = self.get_start_ptr(ctx); let pstart = self.get_start_ptr(ctx);
let var_name = name let var_name = name
.map(|v| v.to_string()) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.start"))) .or_else(|| self.1.map(|v| format!("{v}.start")))
.unwrap_or_default(); .unwrap_or_default();
@ -362,7 +373,7 @@ impl<'ctx> RangeValue<'ctx> {
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pend = self.get_end_ptr(ctx); let pend = self.get_end_ptr(ctx);
let var_name = name let var_name = name
.map(|v| v.to_string()) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.end"))) .or_else(|| self.1.map(|v| format!("{v}.end")))
.unwrap_or_default(); .unwrap_or_default();
@ -387,7 +398,7 @@ impl<'ctx> RangeValue<'ctx> {
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstep = self.get_step_ptr(ctx); let pstep = self.get_step_ptr(ctx);
let var_name = name let var_name = name
.map(|v| v.to_string()) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.step"))) .or_else(|| self.1.map(|v| format!("{v}.step")))
.unwrap_or_default(); .unwrap_or_default();
@ -458,7 +469,8 @@ impl<'ctx> NDArrayValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [NDArrayValue] from a [PointerValue]. /// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val( pub fn from_ptr_val(
ptr: PointerValue<'ctx>, ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
@ -468,7 +480,8 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue(ptr, name) NDArrayValue(ptr, name)
} }
/// Returns the underlying [PointerValue] pointing to the `NDArray` instance. /// Returns the underlying [`PointerValue`] pointing to the `NDArray` instance.
#[must_use]
pub fn get_ptr(&self) -> PointerValue<'ctx> { pub fn get_ptr(&self) -> PointerValue<'ctx> {
self.0 self.0
} }
@ -539,8 +552,9 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> { pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> {
NDArrayDimsProxy(self.clone()) NDArrayDimsProxy(*self)
} }
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
@ -575,8 +589,9 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Returns a proxy object to the field storing the data of this `NDArray`. /// Returns a proxy object to the field storing the data of this `NDArray`.
#[must_use]
pub fn get_data(&self) -> NDArrayDataProxy<'ctx> { pub fn get_data(&self) -> NDArrayDataProxy<'ctx> {
NDArrayDataProxy(self.clone()) NDArrayDataProxy(*self)
} }
} }
@ -665,6 +680,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
.unwrap() .unwrap()
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn ptr_to_data_flattened_unchecked( pub unsafe fn ptr_to_data_flattened_unchecked(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -710,6 +728,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
} }
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn get_flattened_unchecked( pub unsafe fn get_flattened_unchecked(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -732,6 +753,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
} }
/// # Safety
///
/// This function should be called with valid indices.
pub unsafe fn ptr_offset_unchecked( pub unsafe fn ptr_offset_unchecked(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -750,7 +774,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx, ctx,
self.0, self.0,
indices, indices,
).unwrap(); );
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -761,6 +785,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
} }
} }
/// # Safety
///
/// This function should be called with valid indices.
pub unsafe fn ptr_offset_unchecked_const( pub unsafe fn ptr_offset_unchecked_const(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -773,7 +800,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx, ctx,
self.0, self.0,
indices, indices,
).unwrap(); );
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -953,6 +980,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
} }
} }
/// # Safety
///
/// This function should be called with valid indices.
pub unsafe fn get_unsafe_const( pub unsafe fn get_unsafe_const(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -964,6 +994,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
} }
/// # Safety
///
/// This function should be called with valid indices.
pub unsafe fn get_unsafe( pub unsafe fn get_unsafe(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -1180,22 +1180,21 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
} }
}; };
let signature = match ctx.calls.get(&loc.into()) { let signature = if let Some(call) = ctx.calls.get(&loc.into()) {
Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), ctx.unifier.get_call_signature(*call).unwrap()
None => { } else {
let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap()); let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap());
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else { let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
unreachable!("must be tobj") unreachable!("must be tobj")
}; };
let fn_ty = fields.get(&op_name).unwrap().0; let fn_ty = fields.get(&op_name).unwrap().0;
let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty); let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty);
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else {
unreachable!() unreachable!()
}; };
sig.clone() sig.clone()
},
}; };
let fun_id = { let fun_id = {
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
@ -1380,7 +1379,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
.unwrap(), .unwrap(),
ctx.builder ctx.builder
.build_int_mul( .build_int_mul(
ndarray_num_dims.into(), ndarray_num_dims,
llvm_usize.size_of(), llvm_usize.size_of(),
"", "",
) )
@ -1426,7 +1425,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
.unwrap(), .unwrap(),
ctx.builder ctx.builder
.build_int_mul( .build_int_mul(
ndarray_num_elems.into(), ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(), llvm_ndarray_data_t.size_of().unwrap(),
"", "",
) )
@ -2078,7 +2077,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
*ty, *ty,
*ndims, *ndims,
v, v,
&*slice, slice,
) )
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {

View File

@ -94,9 +94,9 @@ pub trait CodeGenerator {
/// Allocate memory for a variable and return a pointer pointing to it. /// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function. /// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx, 'a>( fn gen_array_var_alloc<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>, size: IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,

View File

@ -569,11 +569,11 @@ pub fn call_j0<'ctx>(
.unwrap() .unwrap()
} }
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size. /// calculated total size.
/// ///
/// * `num_dims` - An [IntValue] containing the number of dimensions. /// * `num_dims` - An [`IntValue`] containing the number of dimensions.
/// * `dims` - A [PointerValue] to an array containing the size of each dimensions. /// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
pub fn call_ndarray_calc_size<'ctx>( pub fn call_ndarray_calc_size<'ctx>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -619,9 +619,9 @@ pub fn call_ndarray_calc_size<'ctx>(
/// Generates a call to `__nac3_ndarray_init_dims`. /// Generates a call to `__nac3_ndarray_init_dims`.
/// ///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. /// `NDArray`.
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM /// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM
/// representation of a `list`. /// representation of a `list`.
pub fn call_ndarray_init_dims<'ctx>( pub fn call_ndarray_init_dims<'ctx>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
@ -674,14 +674,14 @@ pub fn call_ndarray_init_dims<'ctx>(
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. /// Generates a call to `__nac3_ndarray_calc_nd_indices`.
/// ///
/// * `index` - The index to compute the multidimensional index for. /// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. /// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx>( pub fn call_ndarray_calc_nd_indices<'ctx>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>, index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> PointerValue<'ctx> {
let llvm_void = ctx.ctx.void_type(); let llvm_void = ctx.ctx.void_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -728,7 +728,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
) )
.unwrap(); .unwrap();
Ok(indices) indices
} }
fn call_ndarray_flatten_index_impl<'ctx>( fn call_ndarray_flatten_index_impl<'ctx>(
@ -737,7 +737,7 @@ fn call_ndarray_flatten_index_impl<'ctx>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>, indices: PointerValue<'ctx>,
indices_size: IntValue<'ctx>, indices_size: IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String> { ) -> IntValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -746,7 +746,7 @@ fn call_ndarray_flatten_index_impl<'ctx>(
debug_assert_eq!( debug_assert_eq!(
IntType::try_from(indices.get_type().get_element_type()) IntType::try_from(indices.get_type().get_element_type())
.map(|itype| itype.get_bit_width()) .map(IntType::get_bit_width)
.unwrap_or_default(), .unwrap_or_default(),
llvm_i32.get_bit_width(), llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
@ -795,13 +795,13 @@ fn call_ndarray_flatten_index_impl<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap(); .unwrap();
Ok(index) index
} }
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index. /// multidimensional index.
/// ///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. /// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for. /// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>( pub fn call_ndarray_flatten_index<'ctx>(
@ -809,7 +809,7 @@ pub fn call_ndarray_flatten_index<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>, indices: ListValue<'ctx>,
) -> Result<IntValue<'ctx>, String> { ) -> IntValue<'ctx> {
let indices_size = indices.load_size(ctx, None); let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data(); let indices_data = indices.get_data();
@ -824,7 +824,7 @@ pub fn call_ndarray_flatten_index<'ctx>(
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index. /// multidimensional index.
/// ///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. /// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for. /// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index_const<'ctx>( pub fn call_ndarray_flatten_index_const<'ctx>(
@ -832,7 +832,7 @@ pub fn call_ndarray_flatten_index_const<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: ArrayValue<'ctx>, indices: ArrayValue<'ctx>,
) -> Result<IntValue<'ctx>, String> { ) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.get_type().len(); let indices_size = indices.get_type().len();
@ -841,7 +841,7 @@ pub fn call_ndarray_flatten_index_const<'ctx>(
indices.get_type().get_element_type(), indices.get_type().get_element_type(),
llvm_usize.const_int(indices_size as u64, false), llvm_usize.const_int(indices_size as u64, false),
None None
)?; ).unwrap();
for i in 0..indices_size { for i in 0..indices_size {
let v = ctx.builder.build_extract_value(indices, i, "") let v = ctx.builder.build_extract_value(indices, i, "")
.unwrap() .unwrap()

View File

@ -408,6 +408,7 @@ pub struct CodeGenTask {
/// ///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable /// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`. /// would be represented by an `i8`.
#[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx>( fn get_llvm_type<'ctx>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
@ -543,6 +544,7 @@ fn get_llvm_type<'ctx>(
/// ABI representation is that the in-memory representation must be at least byte-sized and must /// ABI representation is that the in-memory representation must be at least byte-sized and must
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such /// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations. /// restriction for ABI representations.
#[allow(clippy::too_many_arguments)]
fn get_llvm_abi_type<'ctx>( fn get_llvm_abi_type<'ctx>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
@ -809,7 +811,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
/* filename */ /* filename */
&task &task
.body .body
.get(0) .first()
.map_or_else( .map_or_else(
|| "<nac3_internal>".to_string(), || "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(), |f| f.location.file.0.to_string(),
@ -839,7 +841,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
inkwell::debug_info::DIFlags::PUBLIC, inkwell::debug_info::DIFlags::PUBLIC,
); );
let (row, col) = let (row, col) =
task.body.get(0).map_or_else(|| (0, 0), |b| (b.location.row, b.location.column)); task.body.first().map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
let func_scope: DISubprogram<'_> = dibuilder.create_function( let func_scope: DISubprogram<'_> = dibuilder.create_function(
/* scope */ compile_unit.as_debug_info_scope(), /* scope */ compile_unit.as_debug_info_scope(),
/* func name */ symbol, /* func name */ symbol,

View File

@ -55,7 +55,7 @@ pub fn gen_var<'ctx>(
Ok(ptr) Ok(ptr)
} }
/// See [CodeGenerator::gen_array_var_alloc]. /// See [`CodeGenerator::gen_array_var_alloc`].
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>( pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
ty: T, ty: T,
@ -484,7 +484,7 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let init_bb = ctx.ctx.append_basic_block(current, "for.init"); let init_bb = ctx.ctx.append_basic_block(current, "for.init");
// The BB containing the loop condition check // The BB containing the loop condition check
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");

View File

@ -119,7 +119,6 @@ impl SymbolValue {
/// * `constant` - The constant to create the value from. /// * `constant` - The constant to create the value from.
pub fn from_constant_inferred( pub fn from_constant_inferred(
constant: &Constant, constant: &Constant,
unifier: &mut Unifier
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => Ok(SymbolValue::OptionNone), Constant::None => Ok(SymbolValue::OptionNone),
@ -140,7 +139,7 @@ impl SymbolValue {
Constant::Tuple(t) => { Constant::Tuple(t) => {
let elems = t let elems = t
.iter() .iter()
.map(|constant| Self::from_constant_inferred(constant, unifier)) .map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?; .collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems)) Ok(SymbolValue::Tuple(elems))
} }
@ -507,7 +506,7 @@ pub fn parse_type_annotation<T>(
let values = if let Tuple { elts, .. } = &slice.node { let values = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter()
.map(|elt| parse_literal(elt)) .map(&mut parse_literal)
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
} else { } else {
vec![parse_literal(slice)?] vec![parse_literal(slice)?]
@ -577,8 +576,8 @@ pub fn parse_type_annotation<T>(
])) ]))
} }
} }
Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier) Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone()))) .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])), .map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([ _ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location), format!("unsupported type expression at {}", expr.location),

View File

@ -18,6 +18,7 @@ use crate::{
gen_ndarray_empty, gen_ndarray_empty,
gen_ndarray_eye, gen_ndarray_eye,
gen_ndarray_full, gen_ndarray_full,
gen_ndarray_identity,
gen_ndarray_ones, gen_ndarray_ones,
gen_ndarray_zeros, gen_ndarray_zeros,
}, },
@ -30,7 +31,6 @@ use inkwell::{
IntPredicate IntPredicate
}; };
use itertools::Either; use itertools::Either;
use crate::toplevel::numpy::gen_ndarray_identity;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>; type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
@ -903,7 +903,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// type variable // type variable
&[(list_int32, "shape")], &[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator) gen_ndarray_empty(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
@ -916,7 +916,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// type variable // type variable
&[(list_int32, "shape")], &[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator) gen_ndarray_empty(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
@ -929,7 +929,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// type variable // type variable
&[(list_int32, "shape")], &[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_zeros(ctx, obj, fun, args, generator) gen_ndarray_zeros(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
@ -942,7 +942,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// type variable // type variable
&[(list_int32, "shape")], &[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_ones(ctx, obj, fun, args, generator) gen_ndarray_ones(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
@ -958,7 +958,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// type variable // type variable
&[(list_int32, "shape"), (tv, "fill_value")], &[(list_int32, "shape"), (tv, "fill_value")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, obj, fun, args, generator) gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
) )
@ -980,13 +980,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: ndarray_float_2d, ret: ndarray_float_2d,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_eye(ctx, obj, fun, args, generator) gen_ndarray_eye(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}, },
)))), )))),
@ -999,7 +999,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ndarray_float_2d, ndarray_float_2d,
&[(int32, "n")], &[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, obj, fun, args, generator) gen_ndarray_identity(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
@ -1527,14 +1527,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
None, None,
).into_int_value(); ).into_int_value();
if len.get_type().get_bit_width() != 32 { if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(ctx.builder Some(ctx.builder
.build_int_truncate(len, llvm_i32, "len") .build_int_truncate(len, llvm_i32, "len")
.map(Into::into) .map(Into::into)
.unwrap() .unwrap()
) )
} else {
Some(len.into())
} }
} }
_ => unreachable!(), _ => unreachable!(),

View File

@ -1350,9 +1350,9 @@ impl TopLevelComposer {
])) ]))
} }
} }
ast::StmtKind::Assign { .. } => {}, // we don't class attributes ast::StmtKind::Assign { .. } // we don't class attributes
ast::StmtKind::Pass { .. } => {} | ast::StmtKind::Expr { value: _, .. } // typically a docstring; ignoring all expressions matches CPython behavior
ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior | ast::StmtKind::Pass { .. } => {}
_ => { _ => {
return Err(HashSet::from([ return Err(HashSet::from([
format!( format!(

View File

@ -21,10 +21,10 @@ use crate::{
/// Creates an `NDArray` instance from a constant shape. /// Creates an `NDArray` instance from a constant shape.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue]. /// * `shape` - The shape of the `NDArray`, represented as an LLVM [`ArrayValue`].
fn create_ndarray_const_shape<'ctx, 'a>( fn create_ndarray_const_shape<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
shape: ArrayValue<'ctx> shape: ArrayValue<'ctx>
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
@ -94,9 +94,9 @@ fn create_ndarray_const_shape<'ctx, 'a>(
Ok(ndarray) Ok(ndarray)
} }
fn ndarray_zero_value<'ctx, 'a>( fn ndarray_zero_value<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
@ -108,15 +108,15 @@ fn ndarray_zero_value<'ctx, 'a>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into() ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").into() ctx.gen_string(generator, "")
} else { } else {
unreachable!() unreachable!()
} }
} }
fn ndarray_one_value<'ctx, 'a>( fn ndarray_one_value<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
@ -130,7 +130,7 @@ fn ndarray_one_value<'ctx, 'a>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into() ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").into() ctx.gen_string(generator, "1")
} else { } else {
unreachable!() unreachable!()
} }
@ -138,11 +138,11 @@ fn ndarray_one_value<'ctx, 'a>(
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. /// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the NDArray. /// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_empty_impl<'ctx, 'a>( fn call_ndarray_empty_impl<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
shape: ListValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
@ -308,14 +308,14 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
/// ///
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements /// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
/// with the given value (as opposed to all elements within the array). /// with the given value (as opposed to all elements within the array).
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>( fn ndarray_fill_indexed<'ctx, ValueFn>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn, value_fn: ValueFn,
) -> Result<(), String> ) -> Result<(), String>
where where
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>, ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{ {
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
@ -327,7 +327,7 @@ fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
ctx, ctx,
idx, idx,
ndarray, ndarray,
)?; );
value_fn(generator, ctx, indices) value_fn(generator, ctx, indices)
} }
@ -336,11 +336,11 @@ fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the NDArray. /// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_zeros_impl<'ctx, 'a>( fn call_ndarray_zeros_impl<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
shape: ListValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
@ -372,11 +372,11 @@ fn call_ndarray_zeros_impl<'ctx, 'a>(
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. /// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the NDArray. /// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_ones_impl<'ctx, 'a>( fn call_ndarray_ones_impl<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
shape: ListValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
@ -408,11 +408,11 @@ fn call_ndarray_ones_impl<'ctx, 'a>(
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. /// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the NDArray. /// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_full_impl<'ctx, 'a>( fn call_ndarray_full_impl<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
shape: ListValue<'ctx>, shape: ListValue<'ctx>,
fill_value: BasicValueEnum<'ctx>, fill_value: BasicValueEnum<'ctx>,
@ -465,7 +465,7 @@ fn call_ndarray_full_impl<'ctx, 'a>(
copy.into() copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() { } else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value.into() fill_value
} else { } else {
unreachable!() unreachable!()
}; };
@ -479,10 +479,10 @@ fn call_ndarray_full_impl<'ctx, 'a>(
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. /// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
fn call_ndarray_eye_impl<'ctx, 'a>( fn call_ndarray_eye_impl<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
nrows: IntValue<'ctx>, nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>, ncols: IntValue<'ctx>,
@ -552,11 +552,11 @@ fn call_ndarray_eye_impl<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.empty`. /// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx, 'a>( pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());
@ -576,11 +576,11 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.zeros`. /// Generates LLVM IR for `ndarray.zeros`.
pub fn gen_ndarray_zeros<'ctx, 'a>( pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());
@ -600,11 +600,11 @@ pub fn gen_ndarray_zeros<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.ones`. /// Generates LLVM IR for `ndarray.ones`.
pub fn gen_ndarray_ones<'ctx, 'a>( pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());
@ -624,11 +624,11 @@ pub fn gen_ndarray_ones<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.full`. /// Generates LLVM IR for `ndarray.full`.
pub fn gen_ndarray_full<'ctx, 'a>( pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());
@ -652,11 +652,11 @@ pub fn gen_ndarray_full<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.eye`. /// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx, 'a>( pub fn gen_ndarray_eye<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());
@ -668,7 +668,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
let ncols_ty = fun.0.args[1].ty; let ncols_ty = fun.0.args[1].ty;
let ncols_arg = args.iter() let ncols_arg = args.iter()
.find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false)) .find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)) .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
.unwrap_or_else(|| { .unwrap_or_else(|| {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty) args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
@ -676,7 +676,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
let offset_ty = fun.0.args[2].ty; let offset_ty = fun.0.args[2].ty;
let offset_arg = args.iter() let offset_arg = args.iter()
.find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false)) .find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty)) .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
.unwrap_or_else(|| { .unwrap_or_else(|| {
Ok(context.gen_symbol_val( Ok(context.gen_symbol_val(
@ -697,11 +697,11 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
} }
/// Generates LLVM IR for `ndarray.identity`. /// Generates LLVM IR for `ndarray.identity`.
pub fn gen_ndarray_identity<'ctx, 'a>( pub fn gen_ndarray_identity<'ctx>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none()); assert!(obj.is_none());

View File

@ -519,7 +519,7 @@ pub fn get_type_from_type_annotation_kinds(
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => { TypeAnnotation::Literal(values) => {
let values = values.iter() let values = values.iter()
.map(|v| SymbolValue::from_constant_inferred(v, unifier)) .map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?; .map_err(|err| HashSet::from([err]))?;

View File

@ -52,6 +52,7 @@ pub struct PrimitiveStore {
impl PrimitiveStore { impl PrimitiveStore {
/// Returns a [Type] representing `size_t`. /// Returns a [Type] representing `size_t`.
#[must_use]
pub fn usize(&self) -> Type { pub fn usize(&self) -> Type {
match self.size_t { match self.size_t {
32 => self.uint32, 32 => self.uint32,
@ -1074,6 +1075,7 @@ impl<'a> Inferencer<'a> {
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
} }
#[allow(clippy::unnecessary_wraps)]
fn infer_identifier(&mut self, id: StrRef) -> InferenceResult { fn infer_identifier(&mut self, id: StrRef) -> InferenceResult {
Ok(if let Some(ty) = self.variable_mapping.get(&id) { Ok(if let Some(ty) = self.variable_mapping.get(&id) {
*ty *ty
@ -1126,6 +1128,7 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TList { ty })) Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
} }
#[allow(clippy::unnecessary_wraps)]
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult { fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = elts.iter().map(|x| x.custom.unwrap()).collect(); let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
@ -1242,18 +1245,18 @@ impl<'a> Inferencer<'a> {
&mut self, &mut self,
value: &ast::Expr<Option<Type>>, value: &ast::Expr<Option<Type>>,
dummy_tvar: Type, dummy_tvar: Type,
ndims: &Type, ndims: Type,
) -> InferenceResult { ) -> InferenceResult {
debug_assert!(matches!( debug_assert!(matches!(
&*self.unifier.get_ty_immutable(dummy_tvar), &*self.unifier.get_ty_immutable(dummy_tvar),
TypeEnum::TVar { is_const_generic: false, .. } TypeEnum::TVar { is_const_generic: false, .. }
)); ));
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims }); let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims });
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else { let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims)) panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(ndims))
}; };
let ndims = values.iter() let ndims = values.iter()
@ -1320,7 +1323,7 @@ impl<'a> Inferencer<'a> {
} }
ExprKind::Constant { value: ast::Constant::Int(val), .. } => { ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, ty, *ndims)
} else { } else {
// the index is a constant, so value can be a sequence. // the index is a constant, so value can be a sequence.
let ind: Option<i32> = (*val).try_into().ok(); let ind: Option<i32> = (*val).try_into().ok();
@ -1350,7 +1353,7 @@ impl<'a> Inferencer<'a> {
} }
TypeEnum::TNDArray { ndims, .. } => { TypeEnum::TNDArray { ndims, .. } => {
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?; self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, ty, *ndims)
} }
_ => unreachable!(), _ => unreachable!(),
} }

View File

@ -206,7 +206,7 @@ impl TypeEnum {
} }
} }
/// Returns a [TypeEnum] representing a generic `ndarray` type. /// Returns a [`TypeEnum`] representing a generic `ndarray` type.
/// ///
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic. /// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic. /// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
@ -262,13 +262,13 @@ impl Unifier {
} }
} }
/// Sets the [PrimitiveStore] instance within this `Unifier`. /// Sets the [`PrimitiveStore`] instance within this `Unifier`.
/// ///
/// This function can only be invoked once. Any subsequent invocations will result in an /// This function can only be invoked once. Any subsequent invocations will result in an
/// assertion error.. /// assertion error.
pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) { pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) {
assert!(self.primitive_store.is_none()); assert!(self.primitive_store.is_none());
self.primitive_store.replace(primitives.clone()); self.primitive_store.replace(*primitives);
} }
pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> { pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> {
@ -496,18 +496,22 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } | TLiteral { .. } => true, TRigidVar { .. }
| TLiteral { .. }
// functions are instantiated for each call sites, so the function type can contain
// type variables.
| TFunc { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty }
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars), | TVirtual { ty }
| TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
} }
// functions are instantiated for each call sites, so the function type can contain
// type variables.
TFunc { .. } => true,
} }
} }
@ -748,8 +752,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) | (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty } | TNDArray { ty, .. }) => {
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => {
for (k, v) in fields { for (k, v) in fields {
match *k { match *k {
RecordKey::Int(_) => { RecordKey::Int(_) => {
@ -795,7 +798,7 @@ impl Unifier {
SymbolValue::I64(v) => v as i128, SymbolValue::I64(v) => v as i128,
SymbolValue::U32(v) => v as i128, SymbolValue::U32(v) => v as i128,
SymbolValue::U64(v) => v as i128, SymbolValue::U64(v) => v as i128,
_ => return self.incompatible_types(a, b), _ => return Self::incompatible_types(a, b),
}; };
let can_convert = if self.unioned(ty, primitives.int32) { let can_convert = if self.unioned(ty, primitives.int32) {
@ -811,7 +814,7 @@ impl Unifier {
}; };
if !can_convert { if !can_convert {
return self.incompatible_types(a, b) return Self::incompatible_types(a, b)
} }
} }
@ -836,7 +839,7 @@ impl Unifier {
let v2i = symbol_value_to_int(v2); let v2i = symbol_value_to_int(v2);
if v1i != v2i { if v1i != v2i {
return self.incompatible_types(a, b) return Self::incompatible_types(a, b)
} }
} }
} }
@ -863,10 +866,10 @@ impl Unifier {
} }
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
if self.unify_impl(*ty1, *ty2, false).is_err() { if self.unify_impl(*ty1, *ty2, false).is_err() {
return self.incompatible_types(a, b) return Self::incompatible_types(a, b)
} }
if self.unify_impl(*ndims1, *ndims2, false).is_err() { if self.unify_impl(*ndims1, *ndims2, false).is_err() {
return self.incompatible_types(a, b) return Self::incompatible_types(a, b)
} }
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
@ -945,7 +948,7 @@ impl Unifier {
TObj { obj_id: id2, params: params2, .. }, TObj { obj_id: id2, params: params2, .. },
) => { ) => {
if id1 != id2 { if id1 != id2 {
self.incompatible_types(a, b)?; Self::incompatible_types(a, b)?;
} }
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits // Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
@ -1015,7 +1018,7 @@ impl Unifier {
} }
_ => { _ => {
if swapped { if swapped {
return self.incompatible_types(a, b); return Self::incompatible_types(a, b);
} }
self.unify_impl(b, a, true)?; self.unify_impl(b, a, true)?;
@ -1179,7 +1182,7 @@ impl Unifier {
table.set_value(a, ty_b); table.set_value(a, ty_b);
} }
fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> { fn incompatible_types(a: Type, b: Type) -> Result<(), TypeError> {
Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)) Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))
} }
@ -1326,7 +1329,7 @@ impl Unifier {
None None
} }
} }
_ => { TypeEnum::TCall(_) => {
unreachable!("{} not expected", ty.get_type_name()) unreachable!("{} not expected", ty.get_type_name())
} }
} }