forked from M-Labs/nac3
core: Apply clippy suggestions
This commit is contained in:
parent
8492503af2
commit
49de81ef1e
|
@ -65,8 +65,8 @@ enum Isa {
|
||||||
|
|
||||||
impl Isa {
|
impl Isa {
|
||||||
/// Returns the number of bits in `size_t` for the [`Isa`].
|
/// Returns the number of bits in `size_t` for the [`Isa`].
|
||||||
fn get_size_type(&self) -> u32 {
|
fn get_size_type(self) -> u32 {
|
||||||
if self == &Isa::Host {
|
if self == Isa::Host {
|
||||||
64u32
|
64u32
|
||||||
} else {
|
} else {
|
||||||
32u32
|
32u32
|
||||||
|
|
|
@ -662,10 +662,10 @@ impl InnerResolver {
|
||||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||||
match actual_ty {
|
match actual_ty {
|
||||||
Ok(t) => match unifier.unify(*ty, t) {
|
Ok(t) => match unifier.unify(*ty, t) {
|
||||||
Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
|
Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
|
||||||
Err(e) => Ok(Err(format!(
|
Err(e) => Ok(Err(format!(
|
||||||
"type error ({}) for the ndarray",
|
"type error ({}) for the ndarray",
|
||||||
e.to_display(unifier).to_string()
|
e.to_display(unifier),
|
||||||
))),
|
))),
|
||||||
},
|
},
|
||||||
Err(e) => Ok(Err(e)),
|
Err(e) => Ok(Err(e)),
|
||||||
|
|
|
@ -58,13 +58,15 @@ impl<'ctx> ListValue<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [ListValue] from a [PointerValue].
|
/// Creates an [`ListValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
|
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
|
||||||
assert_is_list(ptr, llvm_usize);
|
assert_is_list(ptr, llvm_usize);
|
||||||
ListValue(ptr, name)
|
ListValue(ptr, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying [PointerValue] pointing to the `list` instance.
|
/// Returns the underlying [`PointerValue`] pointing to the `list` instance.
|
||||||
|
#[must_use]
|
||||||
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
|
@ -119,8 +121,9 @@ impl<'ctx> ListValue<'ctx> {
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
|
#[must_use]
|
||||||
pub fn get_data(&self) -> ListDataProxy<'ctx> {
|
pub fn get_data(&self) -> ListDataProxy<'ctx> {
|
||||||
ListDataProxy(self.clone())
|
ListDataProxy(*self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the `size` of this `list` into this instance.
|
/// Stores the `size` of this `list` into this instance.
|
||||||
|
@ -140,7 +143,7 @@ impl<'ctx> ListValue<'ctx> {
|
||||||
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
let psize = self.get_size_ptr(ctx);
|
let psize = self.get_size_ptr(ctx);
|
||||||
let var_name = name
|
let var_name = name
|
||||||
.map(|v| v.to_string())
|
.map(ToString::to_string)
|
||||||
.or_else(|| self.1.map(|v| format!("{v}.size")))
|
.or_else(|| self.1.map(|v| format!("{v}.size")))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
@ -164,6 +167,9 @@ impl<'ctx> ListDataProxy<'ctx> {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
pub unsafe fn ptr_offset_unchecked(
|
pub unsafe fn ptr_offset_unchecked(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
@ -211,6 +217,9 @@ impl<'ctx> ListDataProxy<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
pub unsafe fn get_unchecked(
|
pub unsafe fn get_unchecked(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -271,13 +280,15 @@ impl<'ctx> RangeValue<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [RangeValue] from a [PointerValue].
|
/// Creates an [`RangeValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
|
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
|
||||||
assert_is_range(ptr);
|
assert_is_range(ptr);
|
||||||
RangeValue(ptr, name)
|
RangeValue(ptr, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying [PointerValue] pointing to the `range` instance.
|
/// Returns the underlying [`PointerValue`] pointing to the `range` instance.
|
||||||
|
#[must_use]
|
||||||
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
|
@ -337,7 +348,7 @@ impl<'ctx> RangeValue<'ctx> {
|
||||||
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
let pstart = self.get_start_ptr(ctx);
|
let pstart = self.get_start_ptr(ctx);
|
||||||
let var_name = name
|
let var_name = name
|
||||||
.map(|v| v.to_string())
|
.map(ToString::to_string)
|
||||||
.or_else(|| self.1.map(|v| format!("{v}.start")))
|
.or_else(|| self.1.map(|v| format!("{v}.start")))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
@ -362,7 +373,7 @@ impl<'ctx> RangeValue<'ctx> {
|
||||||
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
let pend = self.get_end_ptr(ctx);
|
let pend = self.get_end_ptr(ctx);
|
||||||
let var_name = name
|
let var_name = name
|
||||||
.map(|v| v.to_string())
|
.map(ToString::to_string)
|
||||||
.or_else(|| self.1.map(|v| format!("{v}.end")))
|
.or_else(|| self.1.map(|v| format!("{v}.end")))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
@ -387,7 +398,7 @@ impl<'ctx> RangeValue<'ctx> {
|
||||||
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
let pstep = self.get_step_ptr(ctx);
|
let pstep = self.get_step_ptr(ctx);
|
||||||
let var_name = name
|
let var_name = name
|
||||||
.map(|v| v.to_string())
|
.map(ToString::to_string)
|
||||||
.or_else(|| self.1.map(|v| format!("{v}.step")))
|
.or_else(|| self.1.map(|v| format!("{v}.step")))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
@ -458,7 +469,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [NDArrayValue] from a [PointerValue].
|
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
pub fn from_ptr_val(
|
pub fn from_ptr_val(
|
||||||
ptr: PointerValue<'ctx>,
|
ptr: PointerValue<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
|
@ -468,7 +480,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
NDArrayValue(ptr, name)
|
NDArrayValue(ptr, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying [PointerValue] pointing to the `NDArray` instance.
|
/// Returns the underlying [`PointerValue`] pointing to the `NDArray` instance.
|
||||||
|
#[must_use]
|
||||||
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
|
@ -539,8 +552,9 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> {
|
pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> {
|
||||||
NDArrayDimsProxy(self.clone())
|
NDArrayDimsProxy(*self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
|
@ -575,8 +589,9 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
pub fn get_data(&self) -> NDArrayDataProxy<'ctx> {
|
pub fn get_data(&self) -> NDArrayDataProxy<'ctx> {
|
||||||
NDArrayDataProxy(self.clone())
|
NDArrayDataProxy(*self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,6 +680,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
pub unsafe fn ptr_to_data_flattened_unchecked(
|
pub unsafe fn ptr_to_data_flattened_unchecked(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
@ -710,6 +728,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
pub unsafe fn get_flattened_unchecked(
|
pub unsafe fn get_flattened_unchecked(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -732,6 +753,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with valid indices.
|
||||||
pub unsafe fn ptr_offset_unchecked(
|
pub unsafe fn ptr_offset_unchecked(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
@ -750,7 +774,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
ctx,
|
ctx,
|
||||||
self.0,
|
self.0,
|
||||||
indices,
|
indices,
|
||||||
).unwrap();
|
);
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
ctx.builder.build_in_bounds_gep(
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
@ -761,6 +785,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with valid indices.
|
||||||
pub unsafe fn ptr_offset_unchecked_const(
|
pub unsafe fn ptr_offset_unchecked_const(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -773,7 +800,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
ctx,
|
ctx,
|
||||||
self.0,
|
self.0,
|
||||||
indices,
|
indices,
|
||||||
).unwrap();
|
);
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
ctx.builder.build_in_bounds_gep(
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
@ -953,6 +980,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with valid indices.
|
||||||
pub unsafe fn get_unsafe_const(
|
pub unsafe fn get_unsafe_const(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -964,6 +994,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||||
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with valid indices.
|
||||||
pub unsafe fn get_unsafe(
|
pub unsafe fn get_unsafe(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
|
@ -1180,9 +1180,9 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let signature = match ctx.calls.get(&loc.into()) {
|
let signature = if let Some(call) = ctx.calls.get(&loc.into()) {
|
||||||
Some(call) => ctx.unifier.get_call_signature(*call).unwrap(),
|
ctx.unifier.get_call_signature(*call).unwrap()
|
||||||
None => {
|
} else {
|
||||||
let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap());
|
let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap());
|
||||||
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
|
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
|
||||||
unreachable!("must be tobj")
|
unreachable!("must be tobj")
|
||||||
|
@ -1195,7 +1195,6 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
||||||
};
|
};
|
||||||
|
|
||||||
sig.clone()
|
sig.clone()
|
||||||
},
|
|
||||||
};
|
};
|
||||||
let fun_id = {
|
let fun_id = {
|
||||||
let defs = ctx.top_level.definitions.read();
|
let defs = ctx.top_level.definitions.read();
|
||||||
|
@ -1380,7 +1379,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_mul(
|
.build_int_mul(
|
||||||
ndarray_num_dims.into(),
|
ndarray_num_dims,
|
||||||
llvm_usize.size_of(),
|
llvm_usize.size_of(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
@ -1426,7 +1425,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_mul(
|
.build_int_mul(
|
||||||
ndarray_num_elems.into(),
|
ndarray_num_elems,
|
||||||
llvm_ndarray_data_t.size_of().unwrap(),
|
llvm_ndarray_data_t.size_of().unwrap(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
@ -2078,7 +2077,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
*ty,
|
*ty,
|
||||||
*ndims,
|
*ndims,
|
||||||
v,
|
v,
|
||||||
&*slice,
|
slice,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
|
|
|
@ -94,9 +94,9 @@ pub trait CodeGenerator {
|
||||||
|
|
||||||
/// Allocate memory for a variable and return a pointer pointing to it.
|
/// Allocate memory for a variable and return a pointer pointing to it.
|
||||||
/// The default implementation places the allocations at the start of the function.
|
/// The default implementation places the allocations at the start of the function.
|
||||||
fn gen_array_var_alloc<'ctx, 'a>(
|
fn gen_array_var_alloc<'ctx>(
|
||||||
&mut self,
|
&mut self,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ty: BasicTypeEnum<'ctx>,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
size: IntValue<'ctx>,
|
size: IntValue<'ctx>,
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
|
|
|
@ -569,11 +569,11 @@ pub fn call_j0<'ctx>(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
||||||
///
|
///
|
||||||
/// * `num_dims` - An [IntValue] containing the number of dimensions.
|
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
|
||||||
/// * `dims` - A [PointerValue] to an array containing the size of each dimensions.
|
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
|
||||||
pub fn call_ndarray_calc_size<'ctx>(
|
pub fn call_ndarray_calc_size<'ctx>(
|
||||||
generator: &dyn CodeGenerator,
|
generator: &dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -619,9 +619,9 @@ pub fn call_ndarray_calc_size<'ctx>(
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_init_dims`.
|
/// Generates a call to `__nac3_ndarray_init_dims`.
|
||||||
///
|
///
|
||||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
|
/// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM
|
||||||
/// representation of a `list`.
|
/// representation of a `list`.
|
||||||
pub fn call_ndarray_init_dims<'ctx>(
|
pub fn call_ndarray_init_dims<'ctx>(
|
||||||
generator: &dyn CodeGenerator,
|
generator: &dyn CodeGenerator,
|
||||||
|
@ -674,14 +674,14 @@ pub fn call_ndarray_init_dims<'ctx>(
|
||||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
||||||
///
|
///
|
||||||
/// * `index` - The index to compute the multidimensional index for.
|
/// * `index` - The index to compute the multidimensional index for.
|
||||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
pub fn call_ndarray_calc_nd_indices<'ctx>(
|
pub fn call_ndarray_calc_nd_indices<'ctx>(
|
||||||
generator: &dyn CodeGenerator,
|
generator: &dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
index: IntValue<'ctx>,
|
index: IntValue<'ctx>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> PointerValue<'ctx> {
|
||||||
let llvm_void = ctx.ctx.void_type();
|
let llvm_void = ctx.ctx.void_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
@ -728,7 +728,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Ok(indices)
|
indices
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_ndarray_flatten_index_impl<'ctx>(
|
fn call_ndarray_flatten_index_impl<'ctx>(
|
||||||
|
@ -737,7 +737,7 @@ fn call_ndarray_flatten_index_impl<'ctx>(
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
indices: PointerValue<'ctx>,
|
indices: PointerValue<'ctx>,
|
||||||
indices_size: IntValue<'ctx>,
|
indices_size: IntValue<'ctx>,
|
||||||
) -> Result<IntValue<'ctx>, String> {
|
) -> IntValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
@ -746,7 +746,7 @@ fn call_ndarray_flatten_index_impl<'ctx>(
|
||||||
|
|
||||||
debug_assert_eq!(
|
debug_assert_eq!(
|
||||||
IntType::try_from(indices.get_type().get_element_type())
|
IntType::try_from(indices.get_type().get_element_type())
|
||||||
.map(|itype| itype.get_bit_width())
|
.map(IntType::get_bit_width)
|
||||||
.unwrap_or_default(),
|
.unwrap_or_default(),
|
||||||
llvm_i32.get_bit_width(),
|
llvm_i32.get_bit_width(),
|
||||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||||
|
@ -795,13 +795,13 @@ fn call_ndarray_flatten_index_impl<'ctx>(
|
||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Ok(index)
|
index
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
/// multidimensional index.
|
/// multidimensional index.
|
||||||
///
|
///
|
||||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
pub fn call_ndarray_flatten_index<'ctx>(
|
pub fn call_ndarray_flatten_index<'ctx>(
|
||||||
|
@ -809,7 +809,7 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
indices: ListValue<'ctx>,
|
indices: ListValue<'ctx>,
|
||||||
) -> Result<IntValue<'ctx>, String> {
|
) -> IntValue<'ctx> {
|
||||||
let indices_size = indices.load_size(ctx, None);
|
let indices_size = indices.load_size(ctx, None);
|
||||||
let indices_data = indices.get_data();
|
let indices_data = indices.get_data();
|
||||||
|
|
||||||
|
@ -824,7 +824,7 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
||||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
/// multidimensional index.
|
/// multidimensional index.
|
||||||
///
|
///
|
||||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
pub fn call_ndarray_flatten_index_const<'ctx>(
|
pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||||
|
@ -832,7 +832,7 @@ pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
indices: ArrayValue<'ctx>,
|
indices: ArrayValue<'ctx>,
|
||||||
) -> Result<IntValue<'ctx>, String> {
|
) -> IntValue<'ctx> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let indices_size = indices.get_type().len();
|
let indices_size = indices.get_type().len();
|
||||||
|
@ -841,7 +841,7 @@ pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||||
indices.get_type().get_element_type(),
|
indices.get_type().get_element_type(),
|
||||||
llvm_usize.const_int(indices_size as u64, false),
|
llvm_usize.const_int(indices_size as u64, false),
|
||||||
None
|
None
|
||||||
)?;
|
).unwrap();
|
||||||
for i in 0..indices_size {
|
for i in 0..indices_size {
|
||||||
let v = ctx.builder.build_extract_value(indices, i, "")
|
let v = ctx.builder.build_extract_value(indices, i, "")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
|
@ -408,6 +408,7 @@ pub struct CodeGenTask {
|
||||||
///
|
///
|
||||||
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
|
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
|
||||||
/// would be represented by an `i8`.
|
/// would be represented by an `i8`.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn get_llvm_type<'ctx>(
|
fn get_llvm_type<'ctx>(
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
|
@ -543,6 +544,7 @@ fn get_llvm_type<'ctx>(
|
||||||
/// ABI representation is that the in-memory representation must be at least byte-sized and must
|
/// ABI representation is that the in-memory representation must be at least byte-sized and must
|
||||||
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such
|
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such
|
||||||
/// restriction for ABI representations.
|
/// restriction for ABI representations.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn get_llvm_abi_type<'ctx>(
|
fn get_llvm_abi_type<'ctx>(
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
|
@ -809,7 +811,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
|
||||||
/* filename */
|
/* filename */
|
||||||
&task
|
&task
|
||||||
.body
|
.body
|
||||||
.get(0)
|
.first()
|
||||||
.map_or_else(
|
.map_or_else(
|
||||||
|| "<nac3_internal>".to_string(),
|
|| "<nac3_internal>".to_string(),
|
||||||
|f| f.location.file.0.to_string(),
|
|f| f.location.file.0.to_string(),
|
||||||
|
@ -839,7 +841,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
|
||||||
inkwell::debug_info::DIFlags::PUBLIC,
|
inkwell::debug_info::DIFlags::PUBLIC,
|
||||||
);
|
);
|
||||||
let (row, col) =
|
let (row, col) =
|
||||||
task.body.get(0).map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
|
task.body.first().map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
|
||||||
let func_scope: DISubprogram<'_> = dibuilder.create_function(
|
let func_scope: DISubprogram<'_> = dibuilder.create_function(
|
||||||
/* scope */ compile_unit.as_debug_info_scope(),
|
/* scope */ compile_unit.as_debug_info_scope(),
|
||||||
/* func name */ symbol,
|
/* func name */ symbol,
|
||||||
|
|
|
@ -55,7 +55,7 @@ pub fn gen_var<'ctx>(
|
||||||
Ok(ptr)
|
Ok(ptr)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See [CodeGenerator::gen_array_var_alloc].
|
/// See [`CodeGenerator::gen_array_var_alloc`].
|
||||||
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
|
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
ty: T,
|
ty: T,
|
||||||
|
@ -484,7 +484,7 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||||
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap();
|
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
|
||||||
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
||||||
// The BB containing the loop condition check
|
// The BB containing the loop condition check
|
||||||
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
||||||
|
|
|
@ -119,7 +119,6 @@ impl SymbolValue {
|
||||||
/// * `constant` - The constant to create the value from.
|
/// * `constant` - The constant to create the value from.
|
||||||
pub fn from_constant_inferred(
|
pub fn from_constant_inferred(
|
||||||
constant: &Constant,
|
constant: &Constant,
|
||||||
unifier: &mut Unifier
|
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
match constant {
|
match constant {
|
||||||
Constant::None => Ok(SymbolValue::OptionNone),
|
Constant::None => Ok(SymbolValue::OptionNone),
|
||||||
|
@ -140,7 +139,7 @@ impl SymbolValue {
|
||||||
Constant::Tuple(t) => {
|
Constant::Tuple(t) => {
|
||||||
let elems = t
|
let elems = t
|
||||||
.iter()
|
.iter()
|
||||||
.map(|constant| Self::from_constant_inferred(constant, unifier))
|
.map(Self::from_constant_inferred)
|
||||||
.collect::<Result<Vec<SymbolValue>, _>>()?;
|
.collect::<Result<Vec<SymbolValue>, _>>()?;
|
||||||
Ok(SymbolValue::Tuple(elems))
|
Ok(SymbolValue::Tuple(elems))
|
||||||
}
|
}
|
||||||
|
@ -507,7 +506,7 @@ pub fn parse_type_annotation<T>(
|
||||||
|
|
||||||
let values = if let Tuple { elts, .. } = &slice.node {
|
let values = if let Tuple { elts, .. } = &slice.node {
|
||||||
elts.iter()
|
elts.iter()
|
||||||
.map(|elt| parse_literal(elt))
|
.map(&mut parse_literal)
|
||||||
.collect::<Result<Vec<_>, _>>()?
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
} else {
|
} else {
|
||||||
vec![parse_literal(slice)?]
|
vec![parse_literal(slice)?]
|
||||||
|
@ -577,8 +576,8 @@ pub fn parse_type_annotation<T>(
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier)
|
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
|
||||||
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone())))
|
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
|
||||||
.map_err(|err| HashSet::from([err])),
|
.map_err(|err| HashSet::from([err])),
|
||||||
_ => Err(HashSet::from([
|
_ => Err(HashSet::from([
|
||||||
format!("unsupported type expression at {}", expr.location),
|
format!("unsupported type expression at {}", expr.location),
|
||||||
|
|
|
@ -18,6 +18,7 @@ use crate::{
|
||||||
gen_ndarray_empty,
|
gen_ndarray_empty,
|
||||||
gen_ndarray_eye,
|
gen_ndarray_eye,
|
||||||
gen_ndarray_full,
|
gen_ndarray_full,
|
||||||
|
gen_ndarray_identity,
|
||||||
gen_ndarray_ones,
|
gen_ndarray_ones,
|
||||||
gen_ndarray_zeros,
|
gen_ndarray_zeros,
|
||||||
},
|
},
|
||||||
|
@ -30,7 +31,6 @@ use inkwell::{
|
||||||
IntPredicate
|
IntPredicate
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
use crate::toplevel::numpy::gen_ndarray_identity;
|
|
||||||
|
|
||||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||||
|
|
||||||
|
@ -903,7 +903,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
// type variable
|
// type variable
|
||||||
&[(list_int32, "shape")],
|
&[(list_int32, "shape")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
gen_ndarray_empty(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -916,7 +916,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
// type variable
|
// type variable
|
||||||
&[(list_int32, "shape")],
|
&[(list_int32, "shape")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
gen_ndarray_empty(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -929,7 +929,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
// type variable
|
// type variable
|
||||||
&[(list_int32, "shape")],
|
&[(list_int32, "shape")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_zeros(ctx, obj, fun, args, generator)
|
gen_ndarray_zeros(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -942,7 +942,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
// type variable
|
// type variable
|
||||||
&[(list_int32, "shape")],
|
&[(list_int32, "shape")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_ones(ctx, obj, fun, args, generator)
|
gen_ndarray_ones(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -958,7 +958,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
// type variable
|
// type variable
|
||||||
&[(list_int32, "shape"), (tv, "fill_value")],
|
&[(list_int32, "shape"), (tv, "fill_value")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_full(ctx, obj, fun, args, generator)
|
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -980,13 +980,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
ret: ndarray_float_2d,
|
ret: ndarray_float_2d,
|
||||||
vars: var_map.clone(),
|
vars: var_map.clone(),
|
||||||
})),
|
})),
|
||||||
var_id: Default::default(),
|
var_id: Vec::default(),
|
||||||
instance_to_symbol: Default::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: Default::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_eye(ctx, obj, fun, args, generator)
|
gen_ndarray_eye(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
|
@ -999,7 +999,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
ndarray_float_2d,
|
ndarray_float_2d,
|
||||||
&[(int32, "n")],
|
&[(int32, "n")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_identity(ctx, obj, fun, args, generator)
|
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -1527,14 +1527,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
|
||||||
if len.get_type().get_bit_width() != 32 {
|
if len.get_type().get_bit_width() == 32 {
|
||||||
|
Some(len.into())
|
||||||
|
} else {
|
||||||
Some(ctx.builder
|
Some(ctx.builder
|
||||||
.build_int_truncate(len, llvm_i32, "len")
|
.build_int_truncate(len, llvm_i32, "len")
|
||||||
.map(Into::into)
|
.map(Into::into)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
)
|
)
|
||||||
} else {
|
|
||||||
Some(len.into())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
|
|
|
@ -1350,9 +1350,9 @@ impl TopLevelComposer {
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assign { .. } => {}, // we don't class attributes
|
ast::StmtKind::Assign { .. } // we don't class attributes
|
||||||
ast::StmtKind::Pass { .. } => {}
|
| ast::StmtKind::Expr { value: _, .. } // typically a docstring; ignoring all expressions matches CPython behavior
|
||||||
ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior
|
| ast::StmtKind::Pass { .. } => {}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(HashSet::from([
|
return Err(HashSet::from([
|
||||||
format!(
|
format!(
|
||||||
|
|
|
@ -21,10 +21,10 @@ use crate::{
|
||||||
/// Creates an `NDArray` instance from a constant shape.
|
/// Creates an `NDArray` instance from a constant shape.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue].
|
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [`ArrayValue`].
|
||||||
fn create_ndarray_const_shape<'ctx, 'a>(
|
fn create_ndarray_const_shape<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ArrayValue<'ctx>
|
shape: ArrayValue<'ctx>
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
@ -94,9 +94,9 @@ fn create_ndarray_const_shape<'ctx, 'a>(
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_zero_value<'ctx, 'a>(
|
fn ndarray_zero_value<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> BasicValueEnum<'ctx> {
|
) -> BasicValueEnum<'ctx> {
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
@ -108,15 +108,15 @@ fn ndarray_zero_value<'ctx, 'a>(
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||||
ctx.ctx.bool_type().const_zero().into()
|
ctx.ctx.bool_type().const_zero().into()
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||||
ctx.gen_string(generator, "").into()
|
ctx.gen_string(generator, "")
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_one_value<'ctx, 'a>(
|
fn ndarray_one_value<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> BasicValueEnum<'ctx> {
|
) -> BasicValueEnum<'ctx> {
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
@ -130,7 +130,7 @@ fn ndarray_one_value<'ctx, 'a>(
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||||
ctx.ctx.bool_type().const_int(1, false).into()
|
ctx.ctx.bool_type().const_int(1, false).into()
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||||
ctx.gen_string(generator, "1").into()
|
ctx.gen_string(generator, "1")
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
@ -138,11 +138,11 @@ fn ndarray_one_value<'ctx, 'a>(
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
fn call_ndarray_empty_impl<'ctx, 'a>(
|
fn call_ndarray_empty_impl<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
@ -308,14 +308,14 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||||
///
|
///
|
||||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||||
/// with the given value (as opposed to all elements within the array).
|
/// with the given value (as opposed to all elements within the array).
|
||||||
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
fn ndarray_fill_indexed<'ctx, ValueFn>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
value_fn: ValueFn,
|
value_fn: ValueFn,
|
||||||
) -> Result<(), String>
|
) -> Result<(), String>
|
||||||
where
|
where
|
||||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
{
|
{
|
||||||
ndarray_fill_flattened(
|
ndarray_fill_flattened(
|
||||||
generator,
|
generator,
|
||||||
|
@ -327,7 +327,7 @@ fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
||||||
ctx,
|
ctx,
|
||||||
idx,
|
idx,
|
||||||
ndarray,
|
ndarray,
|
||||||
)?;
|
);
|
||||||
|
|
||||||
value_fn(generator, ctx, indices)
|
value_fn(generator, ctx, indices)
|
||||||
}
|
}
|
||||||
|
@ -336,11 +336,11 @@ fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
fn call_ndarray_zeros_impl<'ctx, 'a>(
|
fn call_ndarray_zeros_impl<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
@ -372,11 +372,11 @@ fn call_ndarray_zeros_impl<'ctx, 'a>(
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
fn call_ndarray_ones_impl<'ctx, 'a>(
|
fn call_ndarray_ones_impl<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
@ -408,11 +408,11 @@ fn call_ndarray_ones_impl<'ctx, 'a>(
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
fn call_ndarray_full_impl<'ctx, 'a>(
|
fn call_ndarray_full_impl<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
@ -465,7 +465,7 @@ fn call_ndarray_full_impl<'ctx, 'a>(
|
||||||
|
|
||||||
copy.into()
|
copy.into()
|
||||||
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
||||||
fill_value.into()
|
fill_value
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -479,10 +479,10 @@ fn call_ndarray_full_impl<'ctx, 'a>(
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
fn call_ndarray_eye_impl<'ctx, 'a>(
|
fn call_ndarray_eye_impl<'ctx>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
nrows: IntValue<'ctx>,
|
nrows: IntValue<'ctx>,
|
||||||
ncols: IntValue<'ctx>,
|
ncols: IntValue<'ctx>,
|
||||||
|
@ -552,11 +552,11 @@ fn call_ndarray_eye_impl<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx, 'a>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
@ -576,11 +576,11 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.zeros`.
|
/// Generates LLVM IR for `ndarray.zeros`.
|
||||||
pub fn gen_ndarray_zeros<'ctx, 'a>(
|
pub fn gen_ndarray_zeros<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
@ -600,11 +600,11 @@ pub fn gen_ndarray_zeros<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.ones`.
|
/// Generates LLVM IR for `ndarray.ones`.
|
||||||
pub fn gen_ndarray_ones<'ctx, 'a>(
|
pub fn gen_ndarray_ones<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
@ -624,11 +624,11 @@ pub fn gen_ndarray_ones<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.full`.
|
/// Generates LLVM IR for `ndarray.full`.
|
||||||
pub fn gen_ndarray_full<'ctx, 'a>(
|
pub fn gen_ndarray_full<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
@ -652,11 +652,11 @@ pub fn gen_ndarray_full<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.eye`.
|
/// Generates LLVM IR for `ndarray.eye`.
|
||||||
pub fn gen_ndarray_eye<'ctx, 'a>(
|
pub fn gen_ndarray_eye<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
@ -668,7 +668,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
|
||||||
|
|
||||||
let ncols_ty = fun.0.args[1].ty;
|
let ncols_ty = fun.0.args[1].ty;
|
||||||
let ncols_arg = args.iter()
|
let ncols_arg = args.iter()
|
||||||
.find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false))
|
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||||
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
|
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
||||||
|
@ -676,7 +676,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
|
||||||
|
|
||||||
let offset_ty = fun.0.args[2].ty;
|
let offset_ty = fun.0.args[2].ty;
|
||||||
let offset_arg = args.iter()
|
let offset_arg = args.iter()
|
||||||
.find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false))
|
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
||||||
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
|
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
Ok(context.gen_symbol_val(
|
Ok(context.gen_symbol_val(
|
||||||
|
@ -697,11 +697,11 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.identity`.
|
/// Generates LLVM IR for `ndarray.identity`.
|
||||||
pub fn gen_ndarray_identity<'ctx, 'a>(
|
pub fn gen_ndarray_identity<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
|
|
|
@ -519,7 +519,7 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
||||||
TypeAnnotation::Literal(values) => {
|
TypeAnnotation::Literal(values) => {
|
||||||
let values = values.iter()
|
let values = values.iter()
|
||||||
.map(|v| SymbolValue::from_constant_inferred(v, unifier))
|
.map(SymbolValue::from_constant_inferred)
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.map_err(|err| HashSet::from([err]))?;
|
.map_err(|err| HashSet::from([err]))?;
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,7 @@ pub struct PrimitiveStore {
|
||||||
|
|
||||||
impl PrimitiveStore {
|
impl PrimitiveStore {
|
||||||
/// Returns a [Type] representing `size_t`.
|
/// Returns a [Type] representing `size_t`.
|
||||||
|
#[must_use]
|
||||||
pub fn usize(&self) -> Type {
|
pub fn usize(&self) -> Type {
|
||||||
match self.size_t {
|
match self.size_t {
|
||||||
32 => self.uint32,
|
32 => self.uint32,
|
||||||
|
@ -1074,6 +1075,7 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
|
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::unnecessary_wraps)]
|
||||||
fn infer_identifier(&mut self, id: StrRef) -> InferenceResult {
|
fn infer_identifier(&mut self, id: StrRef) -> InferenceResult {
|
||||||
Ok(if let Some(ty) = self.variable_mapping.get(&id) {
|
Ok(if let Some(ty) = self.variable_mapping.get(&id) {
|
||||||
*ty
|
*ty
|
||||||
|
@ -1126,6 +1128,7 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
|
Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::unnecessary_wraps)]
|
||||||
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||||
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
|
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
|
||||||
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
||||||
|
@ -1242,18 +1245,18 @@ impl<'a> Inferencer<'a> {
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
value: &ast::Expr<Option<Type>>,
|
||||||
dummy_tvar: Type,
|
dummy_tvar: Type,
|
||||||
ndims: &Type,
|
ndims: Type,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
debug_assert!(matches!(
|
debug_assert!(matches!(
|
||||||
&*self.unifier.get_ty_immutable(dummy_tvar),
|
&*self.unifier.get_ty_immutable(dummy_tvar),
|
||||||
TypeEnum::TVar { is_const_generic: false, .. }
|
TypeEnum::TVar { is_const_generic: false, .. }
|
||||||
));
|
));
|
||||||
|
|
||||||
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims });
|
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims });
|
||||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
||||||
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
||||||
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims))
|
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(ndims))
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndims = values.iter()
|
let ndims = values.iter()
|
||||||
|
@ -1320,7 +1323,7 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, ty, *ndims)
|
||||||
} else {
|
} else {
|
||||||
// the index is a constant, so value can be a sequence.
|
// the index is a constant, so value can be a sequence.
|
||||||
let ind: Option<i32> = (*val).try_into().ok();
|
let ind: Option<i32> = (*val).try_into().ok();
|
||||||
|
@ -1350,7 +1353,7 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { ndims, .. } => {
|
TypeEnum::TNDArray { ndims, .. } => {
|
||||||
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, ty, *ndims)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -206,7 +206,7 @@ impl TypeEnum {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a [TypeEnum] representing a generic `ndarray` type.
|
/// Returns a [`TypeEnum`] representing a generic `ndarray` type.
|
||||||
///
|
///
|
||||||
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
|
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
|
||||||
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
|
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
|
||||||
|
@ -262,13 +262,13 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the [PrimitiveStore] instance within this `Unifier`.
|
/// Sets the [`PrimitiveStore`] instance within this `Unifier`.
|
||||||
///
|
///
|
||||||
/// This function can only be invoked once. Any subsequent invocations will result in an
|
/// This function can only be invoked once. Any subsequent invocations will result in an
|
||||||
/// assertion error..
|
/// assertion error.
|
||||||
pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) {
|
pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) {
|
||||||
assert!(self.primitive_store.is_none());
|
assert!(self.primitive_store.is_none());
|
||||||
self.primitive_store.replace(primitives.clone());
|
self.primitive_store.replace(*primitives);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> {
|
pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> {
|
||||||
|
@ -496,18 +496,22 @@ impl Unifier {
|
||||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||||
use TypeEnum::*;
|
use TypeEnum::*;
|
||||||
match &*self.get_ty(a) {
|
match &*self.get_ty(a) {
|
||||||
TRigidVar { .. } | TLiteral { .. } => true,
|
TRigidVar { .. }
|
||||||
|
| TLiteral { .. }
|
||||||
|
// functions are instantiated for each call sites, so the function type can contain
|
||||||
|
// type variables.
|
||||||
|
| TFunc { .. } => true,
|
||||||
|
|
||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
TList { ty }
|
||||||
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
| TVirtual { ty }
|
||||||
|
| TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
|
||||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
TObj { params: vars, .. } => {
|
TObj { params: vars, .. } => {
|
||||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||||
}
|
}
|
||||||
// functions are instantiated for each call sites, so the function type can contain
|
|
||||||
// type variables.
|
|
||||||
TFunc { .. } => true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -748,8 +752,7 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) |
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty } | TNDArray { ty, .. }) => {
|
||||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => {
|
|
||||||
for (k, v) in fields {
|
for (k, v) in fields {
|
||||||
match *k {
|
match *k {
|
||||||
RecordKey::Int(_) => {
|
RecordKey::Int(_) => {
|
||||||
|
@ -795,7 +798,7 @@ impl Unifier {
|
||||||
SymbolValue::I64(v) => v as i128,
|
SymbolValue::I64(v) => v as i128,
|
||||||
SymbolValue::U32(v) => v as i128,
|
SymbolValue::U32(v) => v as i128,
|
||||||
SymbolValue::U64(v) => v as i128,
|
SymbolValue::U64(v) => v as i128,
|
||||||
_ => return self.incompatible_types(a, b),
|
_ => return Self::incompatible_types(a, b),
|
||||||
};
|
};
|
||||||
|
|
||||||
let can_convert = if self.unioned(ty, primitives.int32) {
|
let can_convert = if self.unioned(ty, primitives.int32) {
|
||||||
|
@ -811,7 +814,7 @@ impl Unifier {
|
||||||
};
|
};
|
||||||
|
|
||||||
if !can_convert {
|
if !can_convert {
|
||||||
return self.incompatible_types(a, b)
|
return Self::incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -836,7 +839,7 @@ impl Unifier {
|
||||||
let v2i = symbol_value_to_int(v2);
|
let v2i = symbol_value_to_int(v2);
|
||||||
|
|
||||||
if v1i != v2i {
|
if v1i != v2i {
|
||||||
return self.incompatible_types(a, b)
|
return Self::incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -863,10 +866,10 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
||||||
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
||||||
return self.incompatible_types(a, b)
|
return Self::incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
if self.unify_impl(*ndims1, *ndims2, false).is_err() {
|
if self.unify_impl(*ndims1, *ndims2, false).is_err() {
|
||||||
return self.incompatible_types(a, b)
|
return Self::incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
|
@ -945,7 +948,7 @@ impl Unifier {
|
||||||
TObj { obj_id: id2, params: params2, .. },
|
TObj { obj_id: id2, params: params2, .. },
|
||||||
) => {
|
) => {
|
||||||
if id1 != id2 {
|
if id1 != id2 {
|
||||||
self.incompatible_types(a, b)?;
|
Self::incompatible_types(a, b)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
|
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
|
||||||
|
@ -1015,7 +1018,7 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if swapped {
|
if swapped {
|
||||||
return self.incompatible_types(a, b);
|
return Self::incompatible_types(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.unify_impl(b, a, true)?;
|
self.unify_impl(b, a, true)?;
|
||||||
|
@ -1179,7 +1182,7 @@ impl Unifier {
|
||||||
table.set_value(a, ty_b);
|
table.set_value(a, ty_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
|
fn incompatible_types(a: Type, b: Type) -> Result<(), TypeError> {
|
||||||
Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))
|
Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1326,7 +1329,7 @@ impl Unifier {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
TypeEnum::TCall(_) => {
|
||||||
unreachable!("{} not expected", ty.get_type_name())
|
unreachable!("{} not expected", ty.get_type_name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue