[core] codegen: Refactor to use CodeGenContext::get_size_type

Simplifies a lot of API usage.
This commit is contained in:
David Mak 2025-01-13 21:05:27 +08:00
parent c59fd286ff
commit bd66fe48d8
35 changed files with 176 additions and 266 deletions

View File

@ -471,7 +471,7 @@ fn format_rpc_arg<'ctx>(
// libproto_artiq: NDArray = [data[..], dim_sz[..]] // libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
@ -556,7 +556,7 @@ fn format_rpc_ret<'ctx>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
@ -697,7 +697,7 @@ fn format_rpc_ret<'ctx>(
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let num_elements = ndarray.size(generator, ctx); let num_elements = ndarray.size(ctx);
let expected_ndarray_nbytes = let expected_ndarray_nbytes =
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap(); ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
@ -809,7 +809,7 @@ fn rpc_codegen_callback_fn<'ctx>(
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let int8 = ctx.ctx.i8_type(); let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx); let size_type = ctx.get_size_type();
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
@ -1167,7 +1167,7 @@ fn polymorphic_print<'ctx>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let suffix = suffix.unwrap_or_default(); let suffix = suffix.unwrap_or_default();

View File

@ -1007,7 +1007,7 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let size_t = generator.get_size_type(ctx.ctx); let size_t = ctx.get_size_type();
let ty = if len == 0 let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{ {
@ -1096,7 +1096,7 @@ impl InnerResolver {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
let dtype = llvm_ndarray.element_type(); let dtype = llvm_ndarray.element_type();

View File

@ -64,7 +64,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
.map_value(arg.into_pointer_value(), None); .map_value(arg.into_pointer_value(), None);
ctx.builder ctx.builder
.build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len")
.unwrap() .unwrap()
} }
@ -835,7 +835,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
let llvm_int64 = ctx.ctx.i64_type(); let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
@ -870,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let size_nez = ctx let size_nez = ctx
.builder .builder
.build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") .build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "")
.unwrap(); .unwrap();
ctx.make_assert( ctx.make_assert(
@ -1676,7 +1676,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr"; const FN_NAME: &str = "np_linalg_qr";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1728,7 +1728,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd"; const FN_NAME: &str = "np_linalg_svd";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1821,7 +1821,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv"; const FN_NAME: &str = "np_linalg_pinv";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1862,7 +1862,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu"; const FN_NAME: &str = "sp_linalg_lu";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1915,7 +1915,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { let BasicValueEnum::PointerValue(x1) = x1 else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1968,7 +1968,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };

View File

@ -165,7 +165,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
.build_global_string_ptr(v, "const") .build_global_string_ptr(v, "const")
.map(|v| v.as_pointer_value().into()) .map(|v| v.as_pointer_value().into())
.unwrap(); .unwrap();
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let size = self.get_size_type().const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type();
ty.const_named_struct(&[str_ptr, size.into()]).into() ty.const_named_struct(&[str_ptr, size.into()]).into()
} }
@ -318,7 +318,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
.build_global_string_ptr(v, "const") .build_global_string_ptr(v, "const")
.map(|v| v.as_pointer_value().into()) .map(|v| v.as_pointer_value().into())
.unwrap(); .unwrap();
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let size = self.get_size_type().const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str); let ty = self.get_llvm_type(generator, self.primitives.str);
let val = let val =
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
@ -820,7 +820,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
let id; let id;
@ -1020,7 +1020,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
} }
let is_vararg = args.iter().any(|arg| arg.is_vararg); let is_vararg = args.iter().any(|arg| arg.is_vararg);
if is_vararg { if is_vararg {
params.push(generator.get_size_type(ctx.ctx).into()); params.push(ctx.get_size_type().into());
} }
let fun_ty = match ret_type { let fun_ty = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg), Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg),
@ -1128,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_t = generator.get_size_type(ctx.ctx); let size_t = ctx.get_size_type();
let zero_size_t = size_t.const_zero(); let zero_size_t = size_t.const_zero();
let zero_32 = int32.const_zero(); let zero_32 = int32.const_zero();
@ -1258,12 +1258,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`
let emit_cont_bb = let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| {
|ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| {
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
list.store_size( list.store_size(
ctx, ctx,
generator,
ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(),
); );
}; };
@ -1274,7 +1272,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} else { } else {
// Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the
// no element matches the predicate // no element matches the predicate
emit_cont_bb(ctx, generator, list); emit_cont_bb(ctx, list);
return Ok(None); return Ok(None);
}; };
@ -1287,7 +1285,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
let Some(elem) = generator.gen_expr(ctx, elt)? else { let Some(elem) = generator.gen_expr(ctx, elt)? else {
// Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents
emit_cont_bb(ctx, generator, list); emit_cont_bb(ctx, list);
return Ok(None); return Ok(None);
}; };
@ -1304,7 +1302,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.unwrap(); .unwrap();
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
emit_cont_bb(ctx, generator, list); emit_cont_bb(ctx, list);
Ok(Some(list.as_base_value().into())) Ok(Some(list.as_base_value().into()))
} }
@ -1350,7 +1348,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
if op.variant == BinopVariant::AugAssign { if op.variant == BinopVariant::AugAssign {
todo!("Augmented assignment operators not implemented for lists") todo!("Augmented assignment operators not implemented for lists")
@ -1972,7 +1970,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let rhs = rhs.into_struct_value(); let rhs = rhs.into_struct_value();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(plhs, lhs).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap();
@ -2000,7 +1998,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None, None,
).into_int_value(); ).into_int_value();
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
if *op == Cmpop::NotEq { if *op == Cmpop::NotEq {
ctx.builder.build_not(result, "").unwrap() ctx.builder.build_not(result, "").unwrap()
} else { } else {
@ -2010,7 +2008,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
.iter() .iter()
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let gen_list_cmpop = |generator: &mut G, let gen_list_cmpop = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>| ctx: &mut CodeGenContext<'ctx, '_>|
@ -2375,7 +2373,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
ctx.current_loc = expr.location; ctx.current_loc = expr.location;
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let usize = generator.get_size_type(ctx.ctx); let usize = ctx.get_size_type();
let zero = int32.const_int(0, false); let zero = int32.const_int(0, false);
let loc = ctx.debug_info.0.create_debug_location( let loc = ctx.debug_info.0.create_debug_location(
@ -2480,7 +2478,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
Some(elements[0].get_type()) Some(elements[0].get_type())
}; };
let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); let length = ctx.get_size_type().const_int(elements.len() as u64, false);
let arr_str_ptr = if let Some(ty) = ty { let arr_str_ptr = if let Some(ty) = ty {
ListType::new(generator, ctx.ctx, ty).construct( ListType::new(generator, ctx.ctx, ty).construct(
generator, generator,
@ -3009,7 +3007,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}; };
let raw_index = ctx let raw_index = ctx
.builder .builder
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") .build_int_s_extend(raw_index, ctx.get_size_type(), "sext")
.unwrap(); .unwrap();
// handle negative index // handle negative index
let is_negative = ctx let is_negative = ctx
@ -3017,7 +3015,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
raw_index, raw_index,
generator.get_size_type(ctx.ctx).const_zero(), ctx.get_size_type().const_zero(),
"is_neg", "is_neg",
) )
.unwrap(); .unwrap();

View File

@ -24,7 +24,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
src_arr: ListValue<'ctx>, src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
@ -168,7 +168,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let new_len = let new_len =
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len); dest_arr.store_size(ctx, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }

View File

@ -68,13 +68,9 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
#[must_use] #[must_use]
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>( pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String {
generator: &G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned(); let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() { match ctx.get_size_type().get_bit_width() {
32 => {} 32 => {}
64 => name.push_str("64"), 64 => name.push_str("64"),
bit_width => { bit_width => {

View File

@ -21,7 +21,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
ndims: IntValue<'ctx>, ndims: IntValue<'ctx>,
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
assert_eq!(ndims.get_type(), llvm_usize); assert_eq!(ndims.get_type(), llvm_usize);
assert_eq!( assert_eq!(
@ -29,11 +29,8 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name( let name =
generator, get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape");
ctx,
"__nac3_ndarray_array_set_and_validate_list_shape",
);
infer_and_call_function( infer_and_call_function(
ctx, ctx,
@ -55,19 +52,14 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
/// - `ndarray.ndims`: Must be initialized. /// - `ndarray.ndims`: Must be initialized.
/// - `ndarray.shape`: Must be initialized. /// - `ndarray.shape`: Must be initialized.
/// - `ndarray.data`: Must be allocated and contiguous. /// - `ndarray.data`: Must be allocated and contiguous.
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
list: ListValue<'ctx>, list: ListValue<'ctx>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) { ) {
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
let name = get_usize_dependent_function_name( let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array");
generator,
ctx,
"__nac3_ndarray_array_write_list_to_array",
);
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -28,11 +28,8 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name( let name =
generator, get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
ctx,
"__nac3_ndarray_util_assert_shape_no_negative",
);
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -57,7 +54,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -69,11 +66,8 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name( let name =
generator, get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same");
ctx,
"__nac3_ndarray_util_assert_output_shape_same",
);
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -94,15 +88,14 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
/// `ndarray`, corresponding to the value of `ndarray.size`. /// `ndarray`, corresponding to the value of `ndarray.size`.
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_size<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -120,15 +113,14 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. /// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_nbytes<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -146,15 +138,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of /// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
/// the `ndarray`, corresponding to the value of `ndarray.__len__`. /// the `ndarray`, corresponding to the value of `ndarray.__len__`.
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_len<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -171,15 +162,14 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
/// Generates a call to `__nac3_ndarray_is_c_contiguous`. /// Generates a call to `__nac3_ndarray_is_c_contiguous`.
/// ///
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. /// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -196,20 +186,19 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
/// Generates a call to `__nac3_ndarray_get_nth_pelement`. /// Generates a call to `__nac3_ndarray_get_nth_pelement`.
/// ///
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. /// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
index: IntValue<'ctx>, index: IntValue<'ctx>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
assert_eq!(index.get_type(), llvm_usize); assert_eq!(index.get_type(), llvm_usize);
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -236,7 +225,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
@ -245,8 +234,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
llvm_usize.into() llvm_usize.into()
); );
let name = let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices");
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -266,15 +254,13 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. /// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
/// ///
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. /// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) { ) {
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape");
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -291,13 +277,12 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number /// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in /// of elements in `src_ndarray` must be greater than or equal to the number of elements in
/// `dst_ndarray`. /// `dst_ndarray`.
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_copy_data<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data");
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -20,13 +20,12 @@ use crate::codegen::{
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. /// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. /// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. /// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_broadcast_to<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
@ -53,7 +52,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
assert_eq!(num_shape_entries.get_type(), llvm_usize); assert_eq!(num_shape_entries.get_type(), llvm_usize);
assert!(ShapeEntryType::is_type( assert!(ShapeEntryType::is_type(
@ -65,7 +64,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
assert_eq!(dst_ndims.get_type(), llvm_usize); assert_eq!(dst_ndims.get_type(), llvm_usize);
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -17,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -25,7 +25,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -33,7 +33,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -53,12 +53,11 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` /// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
/// object. /// object.
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_nditer_has_element<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>, iter: NDIterValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
@ -75,12 +74,8 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
/// Generates a call to `__nac3_nditer_next`. /// Generates a call to `__nac3_nditer_next`.
/// ///
/// Moves `iter` to point to the next element. /// Moves `iter` to point to the next element.
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) {
generator: &G, let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next");
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next");
infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None);
} }

View File

@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
assert_eq!( assert_eq!(
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
@ -43,8 +43,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
llvm_usize.into() llvm_usize.into()
); );
let name = let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -18,14 +18,13 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera
new_ndims: IntValue<'ctx>, new_ndims: IntValue<'ctx>,
new_shape: ArraySliceValue<'ctx>, new_shape: ArraySliceValue<'ctx>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
assert_eq!(size.get_type(), llvm_usize); assert_eq!(size.get_type(), llvm_usize);
assert_eq!(new_ndims.get_type(), llvm_usize); assert_eq!(new_ndims.get_type(), llvm_usize);
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
let name = get_usize_dependent_function_name( let name = get_usize_dependent_function_name(
generator,
ctx, ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape", "__nac3_ndarray_reshape_resolve_and_check_new_shape",
); );

View File

@ -23,12 +23,12 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
) { ) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -2,11 +2,10 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
use itertools::Either; use itertools::Either;
use super::get_usize_dependent_function_name; use super::get_usize_dependent_function_name;
use crate::codegen::{CodeGenContext, CodeGenerator}; use crate::codegen::CodeGenContext;
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( pub fn call_string_eq<'ctx>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
str1_ptr: PointerValue<'ctx>, str1_ptr: PointerValue<'ctx>,
str1_len: IntValue<'ctx>, str1_len: IntValue<'ctx>,
@ -15,7 +14,7 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq"); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { let func = ctx.module.get_function(&func_name).unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(

View File

@ -1212,7 +1212,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let align_ty = align_ty.into(); let align_ty = align_ty.into();
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap(); let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();

View File

@ -207,7 +207,7 @@ pub fn gen_ndarray_eye<'ctx>(
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = context.get_size_type();
let llvm_dtype = context.get_llvm_type(generator, dtype); let llvm_dtype = context.get_llvm_type(generator, dtype);
let nrows = context let nrows = context
@ -244,7 +244,7 @@ pub fn gen_ndarray_identity<'ctx>(
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = context.get_size_type();
let llvm_dtype = context.get_llvm_type(generator, dtype); let llvm_dtype = context.get_llvm_type(generator, dtype);
let n = context let n = context
@ -325,8 +325,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
// Check shapes. // Check shapes.
let a_size = a.size(generator, ctx); let a_size = a.size(ctx);
let b_size = b.size(generator, ctx); let b_size = b.size(ctx);
let same_shape = let same_shape =
ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap();
ctx.make_assert( ctx.make_assert(
@ -353,9 +353,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b);
Ok((a_iter, b_iter)) Ok((a_iter, b_iter))
}, },
|generator, ctx, (a_iter, _b_iter)| { |_, ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status. // Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_element(generator, ctx)) Ok(a_iter.has_element(ctx))
}, },
|_, ctx, _hooks, (a_iter, b_iter)| { |_, ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(ctx); let a_scalar = a_iter.get_scalar(ctx);
@ -385,9 +385,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_store(result, new_result).unwrap(); ctx.builder.build_store(result, new_result).unwrap();
Ok(()) Ok(())
}, },
|generator, ctx, (a_iter, b_iter)| { |_, ctx, (a_iter, b_iter)| {
a_iter.next(generator, ctx); a_iter.next(ctx);
b_iter.next(generator, ctx); b_iter.next(ctx);
Ok(()) Ok(())
}, },
) )

View File

@ -306,7 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{ {
// Handle list item assignment // Handle list item assignment
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
let target = generator let target = generator
@ -367,10 +367,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, key_ty)? .to_basic_value_enum(ctx, generator, key_ty)?
.into_int_value(); .into_int_value();
let index = ctx let index =
.builder ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap();
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index // handle negative index
let is_negative = ctx let is_negative = ctx
@ -378,7 +376,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
index, index,
generator.get_size_type(ctx.ctx).const_zero(), ctx.get_size_type().const_zero(),
"is_neg", "is_neg",
) )
.unwrap(); .unwrap();
@ -460,7 +458,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let target = broadcast_result.ndarrays[0]; let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1]; let value = broadcast_result.ndarrays[1];
target.copy_data_from(generator, ctx, value); target.copy_data_from(ctx, value);
} }
_ => { _ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
@ -484,7 +482,7 @@ pub fn gen_for<G: CodeGenerator>(
let var_assignment = ctx.var_assignment.clone(); let var_assignment = ctx.var_assignment.clone();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_t = generator.get_size_type(ctx.ctx); let size_t = ctx.get_size_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let body_bb = ctx.ctx.append_basic_block(current, "for.body");

View File

@ -152,7 +152,7 @@ impl<'ctx> ListType<'ctx> {
_ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)),
}; };
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
None None
} else { } else {
@ -273,7 +273,7 @@ impl<'ctx> ListType<'ctx> {
} }
let plist = self.alloca_var(generator, ctx, name); let plist = self.alloca_var(generator, ctx, name);
plist.store_size(ctx, generator, len); plist.store_size(ctx, len);
let item = self.item.unwrap_or(self.llvm_usize.into()); let item = self.item.unwrap_or(self.llvm_usize.into());
plist.create_data(ctx, item, None); plist.create_data(ctx, item, None);
@ -300,7 +300,7 @@ impl<'ctx> ListType<'ctx> {
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
let plist = self.alloca_var(generator, ctx, name); let plist = self.alloca_var(generator, ctx, name);
plist.store_size(ctx, generator, self.llvm_usize.const_zero()); plist.store_size(ctx, self.llvm_usize.const_zero());
plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None);
plist plist

View File

@ -67,9 +67,7 @@ impl<'ctx> NDArrayType<'ctx> {
unsafe { ndarray.create_data(generator, ctx) }; unsafe { ndarray.create_data(generator, ctx) };
// Copy all contents from the list. // Copy all contents from the list.
irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray);
generator, ctx, list_value, ndarray,
);
ndarray ndarray
} }
@ -116,7 +114,7 @@ impl<'ctx> NDArrayType<'ctx> {
} }
// Set strides, the `data` is contiguous // Set strides, the `data` is contiguous
ndarray.set_strides_contiguous(generator, ctx); ndarray.set_strides_contiguous(ctx);
ndarray ndarray
} else { } else {

View File

@ -140,7 +140,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> {
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype); let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
} }

View File

@ -86,10 +86,10 @@ impl<'ctx> NDArrayType<'ctx> {
.collect_vec(); .collect_vec();
Ok((nditer, other_nditers)) Ok((nditer, other_nditers))
}, },
|generator, ctx, (out_nditer, _in_nditers)| { |_, ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_element()`. // We can simply use `out_nditer`'s `has_element()`.
// `in_nditers`' `has_element()`s should return the same value. // `in_nditers`' `has_element()`s should return the same value.
Ok(out_nditer.has_element(generator, ctx)) Ok(out_nditer.has_element(ctx))
}, },
|generator, ctx, _hooks, (out_nditer, in_nditers)| { |generator, ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
@ -104,10 +104,10 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(()) Ok(())
}, },
|generator, ctx, (out_nditer, in_nditers)| { |_, ctx, (out_nditer, in_nditers)| {
// Advance all iterators // Advance all iterators
out_nditer.next(generator, ctx); out_nditer.next(ctx);
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); in_nditers.iter().for_each(|nditer| nditer.next(ctx));
Ok(()) Ok(())
}, },
)?; )?;

View File

@ -158,7 +158,7 @@ impl<'ctx> NDArrayType<'ctx> {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype); let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
NDArrayType { NDArrayType {
@ -259,9 +259,9 @@ impl<'ctx> NDArrayType<'ctx> {
.builder .builder
.build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") .build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap(); .unwrap();
ndarray.store_itemsize(ctx, generator, itemsize); ndarray.store_itemsize(ctx, itemsize);
ndarray.store_ndims(ctx, generator, ndims); ndarray.store_ndims(ctx, ndims);
ndarray.create_shape(ctx, self.llvm_usize, ndims); ndarray.create_shape(ctx, self.llvm_usize, ndims);
ndarray.create_strides(ctx, self.llvm_usize, ndims); ndarray.create_strides(ctx, self.llvm_usize, ndims);
@ -307,7 +307,7 @@ impl<'ctx> NDArrayType<'ctx> {
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
.construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();
@ -342,7 +342,7 @@ impl<'ctx> NDArrayType<'ctx> {
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
.construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();

View File

@ -52,7 +52,7 @@ impl<'ctx> TupleType<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, ty: Type,
) -> Self { ) -> Self {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
// Sanity check on object type. // Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {

View File

@ -418,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); debug_assert_eq!(idx.get_type(), ctx.get_size_type());
let size = self.size(ctx, generator); let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();

View File

@ -97,13 +97,8 @@ impl<'ctx> ListValue<'ctx> {
} }
/// Stores the `size` of this `list` into this instance. /// Stores the `size` of this `list` into this instance.
pub fn store_size<G: CodeGenerator + ?Sized>( pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) {
&self, debug_assert_eq!(size.get_type(), ctx.get_size_type());
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
self.len_field(ctx).set(ctx, self.value, size, self.name); self.len_field(ctx).set(ctx, self.value, size, self.name);
} }
@ -213,7 +208,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); debug_assert_eq!(idx.get_type(), ctx.get_size_type());
let size = self.size(ctx, generator); let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();

View File

@ -112,7 +112,7 @@ impl<'ctx> NDArrayValue<'ctx> {
target_shape.base_ptr(ctx, generator), target_shape.base_ptr(ctx, generator),
); );
irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray);
broadcast_ndarray broadcast_ndarray
} }
} }
@ -146,7 +146,7 @@ fn broadcast_shapes<'ctx, G, Shape>(
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx);
assert!(in_shape_entries assert!(in_shape_entries
@ -199,7 +199,7 @@ impl<'ctx> NDArrayType<'ctx> {
) -> BroadcastAllResult<'ctx, G> { ) -> BroadcastAllResult<'ctx, G> {
assert!(!ndarrays.is_empty()); assert!(!ndarrays.is_empty());
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
// Infer the broadcast output ndims. // Infer the broadcast output ndims.
let broadcast_ndims_int = let broadcast_ndims_int =

View File

@ -130,7 +130,7 @@ impl<'ctx> NDArrayValue<'ctx> {
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), |_, ctx| Ok(self.is_c_contiguous(ctx)),
|_, ctx| { |_, ctx| {
// This ndarray is contiguous. // This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
@ -184,7 +184,7 @@ impl<'ctx> NDArrayValue<'ctx> {
// Copy shape and update strides // Copy shape and update strides
let shape = carray.load_shape(ctx); let shape = carray.load_shape(ctx);
ndarray.copy_shape_from_array(generator, ctx, shape); ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.set_strides_contiguous(generator, ctx); ndarray.set_strides_contiguous(ctx);
// Share data // Share data
let data = carray.load_data(ctx); let data = carray.load_data(ctx);

View File

@ -245,7 +245,7 @@ impl<'ctx> RustNDIndex<'ctx> {
} }
RustNDIndex::Slice(in_rust_slice) => { RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = let user_slice_ptr =
SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type())
.alloca_var(generator, ctx, None); .alloca_var(generator, ctx, None);
in_rust_slice.write_to_slice(ctx, user_slice_ptr); in_rust_slice.write_to_slice(ctx, user_slice_ptr);

View File

@ -35,7 +35,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
// Deduce ndims of the result of matmul. // Deduce ndims of the result of matmul.
@ -315,7 +315,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let result_shape = result.shape(); let result_shape = result.shape();
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
out_ndarray.copy_data_from(generator, ctx, result); out_ndarray.copy_data_from(ctx, result);
out_ndarray out_ndarray
} }
} }

View File

@ -81,13 +81,8 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Stores the number of dimensions `ndims` into this instance. /// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>( pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) {
&self, debug_assert_eq!(ndims.get_type(), ctx.get_size_type());
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx); let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap(); ctx.builder.build_store(pndims, ndims).unwrap();
@ -104,13 +99,8 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Stores the size of each element `itemsize` into this instance. /// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>( pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) {
&self, debug_assert_eq!(itemsize.get_type(), ctx.get_size_type());
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
itemsize: IntValue<'ctx>,
) {
debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));
self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name);
} }
@ -205,12 +195,12 @@ impl<'ctx> NDArrayValue<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) { ) {
let nbytes = self.nbytes(generator, ctx); let nbytes = self.nbytes(ctx);
let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
self.store_data(ctx, data); self.store_data(ctx, data);
self.set_strides_contiguous(generator, ctx); self.set_strides_contiguous(ctx);
} }
/// Returns a proxy object to the field storing the data of this `NDArray`. /// Returns a proxy object to the field storing the data of this `NDArray`.
@ -284,52 +274,32 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Get the `np.size()` of this ndarray. /// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>( pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
&self, irrt::ndarray::call_nac3_ndarray_size(ctx, *self)
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
} }
/// Get the `ndarray.nbytes` of this ndarray. /// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>( pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
&self, irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self)
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
} }
/// Get the `len()` of this ndarray. /// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>( pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
&self, irrt::ndarray::call_nac3_ndarray_len(ctx, *self)
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
} }
/// Check if this ndarray is C-contiguous. /// Check if this ndarray is C-contiguous.
/// ///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags> /// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>( pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
&self, irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self)
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
} }
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
/// ///
/// Update the ndarray's strides to make the ndarray contiguous. /// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>( pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) {
&self, irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self);
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
} }
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and
@ -347,7 +317,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let shape = self.shape(); let shape = self.shape();
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
unsafe { clone.create_data(generator, ctx) }; unsafe { clone.create_data(generator, ctx) };
clone.copy_data_from(generator, ctx, *self); clone.copy_data_from(ctx, *self);
clone clone
} }
@ -357,14 +327,9 @@ impl<'ctx> NDArrayValue<'ctx> {
/// do not matter. The copying order is determined by how their flattened views look. /// do not matter. The copying order is determined by how their flattened views look.
/// ///
/// Panics if the `dtype`s of ndarrays are different. /// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>( pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) {
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
) {
assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self);
} }
/// Fill the ndarray with a scalar. /// Fill the ndarray with a scalar.
@ -468,7 +433,7 @@ impl<'ctx> NDArrayValue<'ctx> {
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
if self.is_unsized() { if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible. // NOTE: `np.size(self) == 0` here is never possible.
let zero = generator.get_size_type(ctx.ctx).const_zero(); let zero = ctx.get_size_type().const_zero();
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
Some(value) Some(value)
@ -756,9 +721,9 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
generator: &G, _: &G,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0)
} }
} }
@ -770,7 +735,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx);
// Current implementation is transparent - The returned pointer type is // Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately // already cast into the expected type, allowing for immediately
@ -834,7 +799,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into());
let indices = TypedArrayLikeAdapter::from( let indices = TypedArrayLikeAdapter::from(
indices.as_slice_value(ctx, generator), indices.as_slice_value(ctx, generator),
@ -867,7 +832,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let indices_size = indices.size(ctx, generator); let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx let nidx_leq_ndims = ctx

View File

@ -53,20 +53,16 @@ impl<'ctx> NDIterValue<'ctx> {
/// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false. /// If `ndarray` is 0-sized, this always returns false.
#[must_use] #[must_use]
pub fn has_element<G: CodeGenerator + ?Sized>( pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
&self, irrt::ndarray::call_nac3_nditer_has_element(ctx, *self)
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self)
} }
/// Go to the next element. If `has_element()` is false, then this has undefined behavior. /// Go to the next element. If `has_element()` is false, then this has undefined behavior.
/// ///
/// If `ndarray` is unsized, this can only be called once. /// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called. /// If `ndarray` is 0-sized, this can never be called.
pub fn next<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) { pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) {
irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); irrt::ndarray::call_nac3_nditer_next(ctx, *self);
} }
fn element_field( fn element_field(
@ -167,10 +163,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|generator, ctx| { |generator, ctx| {
Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self))
}, },
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)), |_, ctx, nditer| Ok(nditer.has_element(ctx)),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| { |_, ctx, nditer| {
nditer.next(generator, ctx); nditer.next(ctx);
Ok(()) Ok(())
}, },
) )

View File

@ -30,7 +30,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { ) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = ctx.get_size_type();
let zero = llvm_usize.const_zero(); let zero = llvm_usize.const_zero();
let one = llvm_usize.const_int(1, false); let one = llvm_usize.const_int(1, false);

View File

@ -70,7 +70,7 @@ impl<'ctx> NDArrayValue<'ctx> {
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
// Resolve negative indices // Resolve negative indices
let size = self.size(generator, ctx); let size = self.size(ctx);
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false);
let dst_shape = dst_ndarray.shape(); let dst_shape = dst_ndarray.shape();
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
@ -84,10 +84,10 @@ impl<'ctx> NDArrayValue<'ctx> {
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), |_, ctx| Ok(self.is_c_contiguous(ctx)),
|generator, ctx| { |generator, ctx| {
// Reshape is possible without copying // Reshape is possible without copying
dst_ndarray.set_strides_contiguous(generator, ctx); dst_ndarray.set_strides_contiguous(ctx);
dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator));
Ok(()) Ok(())
@ -97,7 +97,7 @@ impl<'ctx> NDArrayValue<'ctx> {
unsafe { unsafe {
dst_ndarray.create_data(generator, ctx); dst_ndarray.create_data(generator, ctx);
} }
dst_ndarray.copy_data_from(generator, ctx, *self); dst_ndarray.copy_data_from(ctx, *self);
Ok(()) Ok(())
}, },

View File

@ -1278,11 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
let size = ctx let size = ctx
.builder .builder
.build_int_truncate_or_bit_cast( .build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "")
ndarray.size(generator, ctx),
ctx.ctx.i32_type(),
"",
)
.unwrap(); .unwrap();
Ok(Some(size.into())) Ok(Some(size.into()))
}), }),