[core] codegen: Refactor to use CodeGenContext::get_size_type
Simplifies a lot of API usage.
This commit is contained in:
parent
c59fd286ff
commit
bd66fe48d8
@ -471,7 +471,7 @@ fn format_rpc_arg<'ctx>(
|
||||
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
@ -556,7 +556,7 @@ fn format_rpc_ret<'ctx>(
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
@ -697,7 +697,7 @@ fn format_rpc_ret<'ctx>(
|
||||
|
||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let num_elements = ndarray.size(generator, ctx);
|
||||
let num_elements = ndarray.size(ctx);
|
||||
|
||||
let expected_ndarray_nbytes =
|
||||
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
|
||||
@ -809,7 +809,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let int8 = ctx.ctx.i8_type();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let size_type = ctx.get_size_type();
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
@ -1167,7 +1167,7 @@ fn polymorphic_print<'ctx>(
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let suffix = suffix.unwrap_or_default();
|
||||
|
||||
|
@ -1007,7 +1007,7 @@ impl InnerResolver {
|
||||
}
|
||||
_ => unreachable!("must be list"),
|
||||
};
|
||||
let size_t = generator.get_size_type(ctx.ctx);
|
||||
let size_t = ctx.get_size_type();
|
||||
let ty = if len == 0
|
||||
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
|
||||
{
|
||||
@ -1096,7 +1096,7 @@ impl InnerResolver {
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
|
||||
let dtype = llvm_ndarray.element_type();
|
||||
|
||||
|
@ -64,7 +64,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
|
||||
.map_value(arg.into_pointer_value(), None);
|
||||
ctx.builder
|
||||
.build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len")
|
||||
.build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len")
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@ -835,7 +835,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
|
||||
|
||||
let llvm_int64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
@ -870,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let size_nez = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "")
|
||||
.build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
@ -1676,7 +1676,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_qr";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||
|
||||
@ -1728,7 +1728,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_svd";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||
|
||||
@ -1821,7 +1821,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_pinv";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||
|
||||
@ -1862,7 +1862,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_lu";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||
|
||||
@ -1915,7 +1915,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
@ -1968,7 +1968,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||
|
||||
|
@ -165,7 +165,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
|
||||
.build_global_string_ptr(v, "const")
|
||||
.map(|v| v.as_pointer_value().into())
|
||||
.unwrap();
|
||||
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
|
||||
let size = self.get_size_type().const_int(v.len() as u64, false);
|
||||
let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type();
|
||||
ty.const_named_struct(&[str_ptr, size.into()]).into()
|
||||
}
|
||||
@ -318,7 +318,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
|
||||
.build_global_string_ptr(v, "const")
|
||||
.map(|v| v.as_pointer_value().into())
|
||||
.unwrap();
|
||||
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
|
||||
let size = self.get_size_type().const_int(v.len() as u64, false);
|
||||
let ty = self.get_llvm_type(generator, self.primitives.str);
|
||||
let val =
|
||||
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
|
||||
@ -820,7 +820,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
|
||||
let id;
|
||||
@ -1020,7 +1020,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
let is_vararg = args.iter().any(|arg| arg.is_vararg);
|
||||
if is_vararg {
|
||||
params.push(generator.get_size_type(ctx.ctx).into());
|
||||
params.push(ctx.get_size_type().into());
|
||||
}
|
||||
let fun_ty = match ret_type {
|
||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg),
|
||||
@ -1128,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
return Ok(None);
|
||||
};
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_t = generator.get_size_type(ctx.ctx);
|
||||
let size_t = ctx.get_size_type();
|
||||
let zero_size_t = size_t.const_zero();
|
||||
let zero_32 = int32.const_zero();
|
||||
|
||||
@ -1258,12 +1258,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
|
||||
// Emits the content of `cont_bb`
|
||||
let emit_cont_bb =
|
||||
|ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| {
|
||||
let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| {
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
list.store_size(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(),
|
||||
);
|
||||
};
|
||||
@ -1274,7 +1272,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
} else {
|
||||
// Bail if the predicate is an ellipsis - Emit cont_bb contents in case the
|
||||
// no element matches the predicate
|
||||
emit_cont_bb(ctx, generator, list);
|
||||
emit_cont_bb(ctx, list);
|
||||
|
||||
return Ok(None);
|
||||
};
|
||||
@ -1287,7 +1285,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
|
||||
let Some(elem) = generator.gen_expr(ctx, elt)? else {
|
||||
// Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents
|
||||
emit_cont_bb(ctx, generator, list);
|
||||
emit_cont_bb(ctx, list);
|
||||
|
||||
return Ok(None);
|
||||
};
|
||||
@ -1304,7 +1302,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
.unwrap();
|
||||
ctx.builder.build_unconditional_branch(test_bb).unwrap();
|
||||
|
||||
emit_cont_bb(ctx, generator, list);
|
||||
emit_cont_bb(ctx, list);
|
||||
|
||||
Ok(Some(list.as_base_value().into()))
|
||||
}
|
||||
@ -1350,7 +1348,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
if op.variant == BinopVariant::AugAssign {
|
||||
todo!("Augmented assignment operators not implemented for lists")
|
||||
@ -1972,7 +1970,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let rhs = rhs.into_struct_value();
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||
@ -2000,7 +1998,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
||||
let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
||||
if *op == Cmpop::NotEq {
|
||||
ctx.builder.build_not(result, "").unwrap()
|
||||
} else {
|
||||
@ -2010,7 +2008,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
.iter()
|
||||
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let gen_list_cmpop = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||
@ -2375,7 +2373,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
ctx.current_loc = expr.location;
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let usize = generator.get_size_type(ctx.ctx);
|
||||
let usize = ctx.get_size_type();
|
||||
let zero = int32.const_int(0, false);
|
||||
|
||||
let loc = ctx.debug_info.0.create_debug_location(
|
||||
@ -2480,7 +2478,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
} else {
|
||||
Some(elements[0].get_type())
|
||||
};
|
||||
let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false);
|
||||
let length = ctx.get_size_type().const_int(elements.len() as u64, false);
|
||||
let arr_str_ptr = if let Some(ty) = ty {
|
||||
ListType::new(generator, ctx.ctx, ty).construct(
|
||||
generator,
|
||||
@ -3009,7 +3007,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
let raw_index = ctx
|
||||
.builder
|
||||
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
|
||||
.build_int_s_extend(raw_index, ctx.get_size_type(), "sext")
|
||||
.unwrap();
|
||||
// handle negative index
|
||||
let is_negative = ctx
|
||||
@ -3017,7 +3015,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
ctx.get_size_type().const_zero(),
|
||||
"is_neg",
|
||||
)
|
||||
.unwrap();
|
||||
|
@ -24,7 +24,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
src_arr: ListValue<'ctx>,
|
||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
@ -168,7 +168,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx.builder.position_at_end(update_bb);
|
||||
let new_len =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
|
||||
dest_arr.store_size(ctx, generator, new_len);
|
||||
dest_arr.store_size(ctx, new_len);
|
||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
}
|
||||
|
@ -68,13 +68,9 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
||||
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
||||
#[must_use]
|
||||
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
name: &str,
|
||||
) -> String {
|
||||
pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String {
|
||||
let mut name = name.to_owned();
|
||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
match ctx.get_size_type().get_bit_width() {
|
||||
32 => {}
|
||||
64 => name.push_str("64"),
|
||||
bit_width => {
|
||||
|
@ -21,7 +21,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
|
||||
ndims: IntValue<'ctx>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
||||
assert_eq!(ndims.get_type(), llvm_usize);
|
||||
assert_eq!(
|
||||
@ -29,11 +29,8 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
generator,
|
||||
ctx,
|
||||
"__nac3_ndarray_array_set_and_validate_list_shape",
|
||||
);
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
@ -55,19 +52,14 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
|
||||
/// - `ndarray.ndims`: Must be initialized.
|
||||
/// - `ndarray.shape`: Must be initialized.
|
||||
/// - `ndarray.data`: Must be allocated and contiguous.
|
||||
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
list: ListValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
generator,
|
||||
ctx,
|
||||
"__nac3_ndarray_array_write_list_to_array",
|
||||
);
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
|
@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
@ -28,11 +28,8 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
generator,
|
||||
ctx,
|
||||
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||
);
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -57,7 +54,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
|
||||
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
@ -69,11 +66,8 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
generator,
|
||||
ctx,
|
||||
"__nac3_ndarray_util_assert_output_shape_same",
|
||||
);
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -94,15 +88,14 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
|
||||
/// `ndarray`, corresponding to the value of `ndarray.size`.
|
||||
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_size<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -120,15 +113,14 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
|
||||
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
|
||||
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_nbytes<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -146,15 +138,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
|
||||
/// the `ndarray`, corresponding to the value of `ndarray.__len__`.
|
||||
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_len<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -171,15 +162,14 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Generates a call to `__nac3_ndarray_is_c_contiguous`.
|
||||
///
|
||||
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
|
||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -196,20 +186,19 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Generates a call to `__nac3_ndarray_get_nth_pelement`.
|
||||
///
|
||||
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
|
||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
index: IntValue<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
assert_eq!(index.get_type(), llvm_usize);
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -236,7 +225,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
@ -245,8 +234,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -266,15 +254,13 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
|
||||
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
|
||||
///
|
||||
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
|
||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -291,13 +277,12 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
|
||||
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in
|
||||
/// `dst_ndarray`.
|
||||
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_copy_data<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
|
@ -20,13 +20,12 @@ use crate::codegen::{
|
||||
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
|
||||
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
|
||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
||||
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_ndarray_broadcast_to<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
@ -53,7 +52,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(num_shape_entries.get_type(), llvm_usize);
|
||||
assert!(ShapeEntryType::is_type(
|
||||
@ -65,7 +64,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
||||
assert_eq!(dst_ndims.get_type(), llvm_usize);
|
||||
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
|
@ -17,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
|
@ -25,7 +25,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
@ -33,7 +33,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
@ -53,12 +53,11 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
///
|
||||
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
|
||||
/// object.
|
||||
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_nac3_nditer_has_element<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
iter: NDIterValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
@ -75,12 +74,8 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Generates a call to `__nac3_nditer_next`.
|
||||
///
|
||||
/// Moves `iter` to point to the next element.
|
||||
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
iter: NDIterValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
||||
pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next");
|
||||
|
||||
infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None);
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
|
||||
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
|
||||
@ -43,8 +43,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
|
@ -18,14 +18,13 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera
|
||||
new_ndims: IntValue<'ctx>,
|
||||
new_shape: ArraySliceValue<'ctx>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(size.get_type(), llvm_usize);
|
||||
assert_eq!(new_ndims.get_type(), llvm_usize);
|
||||
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
generator,
|
||||
ctx,
|
||||
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
|
||||
);
|
||||
|
@ -23,12 +23,12 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
|
||||
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
|
@ -2,11 +2,10 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
|
||||
use itertools::Either;
|
||||
|
||||
use super::get_usize_dependent_function_name;
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
pub fn call_string_eq<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
str1_ptr: PointerValue<'ctx>,
|
||||
str1_len: IntValue<'ctx>,
|
||||
@ -15,7 +14,7 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq");
|
||||
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
|
||||
|
||||
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
|
@ -1212,7 +1212,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let align_ty = align_ty.into();
|
||||
|
||||
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();
|
||||
|
@ -207,7 +207,7 @@ pub fn gen_ndarray_eye<'ctx>(
|
||||
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
let llvm_usize = context.get_size_type();
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
|
||||
let nrows = context
|
||||
@ -244,7 +244,7 @@ pub fn gen_ndarray_identity<'ctx>(
|
||||
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
let llvm_usize = context.get_size_type();
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
|
||||
let n = context
|
||||
@ -325,8 +325,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
|
||||
|
||||
// Check shapes.
|
||||
let a_size = a.size(generator, ctx);
|
||||
let b_size = b.size(generator, ctx);
|
||||
let a_size = a.size(ctx);
|
||||
let b_size = b.size(ctx);
|
||||
let same_shape =
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
@ -353,9 +353,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b);
|
||||
Ok((a_iter, b_iter))
|
||||
},
|
||||
|generator, ctx, (a_iter, _b_iter)| {
|
||||
|_, ctx, (a_iter, _b_iter)| {
|
||||
// Only a_iter drives the condition, b_iter should have the same status.
|
||||
Ok(a_iter.has_element(generator, ctx))
|
||||
Ok(a_iter.has_element(ctx))
|
||||
},
|
||||
|_, ctx, _hooks, (a_iter, b_iter)| {
|
||||
let a_scalar = a_iter.get_scalar(ctx);
|
||||
@ -385,9 +385,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx.builder.build_store(result, new_result).unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|generator, ctx, (a_iter, b_iter)| {
|
||||
a_iter.next(generator, ctx);
|
||||
b_iter.next(generator, ctx);
|
||||
|_, ctx, (a_iter, b_iter)| {
|
||||
a_iter.next(ctx);
|
||||
b_iter.next(ctx);
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
|
@ -306,7 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Handle list item assignment
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
|
||||
|
||||
let target = generator
|
||||
@ -367,10 +367,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, key_ty)?
|
||||
.into_int_value();
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
|
||||
.unwrap();
|
||||
let index =
|
||||
ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap();
|
||||
|
||||
// handle negative index
|
||||
let is_negative = ctx
|
||||
@ -378,7 +376,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
ctx.get_size_type().const_zero(),
|
||||
"is_neg",
|
||||
)
|
||||
.unwrap();
|
||||
@ -460,7 +458,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
let target = broadcast_result.ndarrays[0];
|
||||
let value = broadcast_result.ndarrays[1];
|
||||
|
||||
target.copy_data_from(generator, ctx, value);
|
||||
target.copy_data_from(ctx, value);
|
||||
}
|
||||
_ => {
|
||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||
@ -484,7 +482,7 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
let var_assignment = ctx.var_assignment.clone();
|
||||
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_t = generator.get_size_type(ctx.ctx);
|
||||
let size_t = ctx.get_size_type();
|
||||
let zero = int32.const_zero();
|
||||
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
|
||||
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
||||
|
@ -152,7 +152,7 @@ impl<'ctx> ListType<'ctx> {
|
||||
_ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)),
|
||||
};
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
|
||||
None
|
||||
} else {
|
||||
@ -273,7 +273,7 @@ impl<'ctx> ListType<'ctx> {
|
||||
}
|
||||
|
||||
let plist = self.alloca_var(generator, ctx, name);
|
||||
plist.store_size(ctx, generator, len);
|
||||
plist.store_size(ctx, len);
|
||||
|
||||
let item = self.item.unwrap_or(self.llvm_usize.into());
|
||||
plist.create_data(ctx, item, None);
|
||||
@ -300,7 +300,7 @@ impl<'ctx> ListType<'ctx> {
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let plist = self.alloca_var(generator, ctx, name);
|
||||
|
||||
plist.store_size(ctx, generator, self.llvm_usize.const_zero());
|
||||
plist.store_size(ctx, self.llvm_usize.const_zero());
|
||||
plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None);
|
||||
|
||||
plist
|
||||
|
@ -67,9 +67,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
// Copy all contents from the list.
|
||||
irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(
|
||||
generator, ctx, list_value, ndarray,
|
||||
);
|
||||
irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray);
|
||||
|
||||
ndarray
|
||||
}
|
||||
@ -116,7 +114,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
}
|
||||
|
||||
// Set strides, the `data` is contiguous
|
||||
ndarray.set_strides_contiguous(generator, ctx);
|
||||
ndarray.set_strides_contiguous(ctx);
|
||||
|
||||
ndarray
|
||||
} else {
|
||||
|
@ -140,7 +140,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
|
||||
}
|
||||
|
@ -86,10 +86,10 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
.collect_vec();
|
||||
Ok((nditer, other_nditers))
|
||||
},
|
||||
|generator, ctx, (out_nditer, _in_nditers)| {
|
||||
|_, ctx, (out_nditer, _in_nditers)| {
|
||||
// We can simply use `out_nditer`'s `has_element()`.
|
||||
// `in_nditers`' `has_element()`s should return the same value.
|
||||
Ok(out_nditer.has_element(generator, ctx))
|
||||
Ok(out_nditer.has_element(ctx))
|
||||
},
|
||||
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||
@ -104,10 +104,10 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|generator, ctx, (out_nditer, in_nditers)| {
|
||||
|_, ctx, (out_nditer, in_nditers)| {
|
||||
// Advance all iterators
|
||||
out_nditer.next(generator, ctx);
|
||||
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
||||
out_nditer.next(ctx);
|
||||
in_nditers.iter().for_each(|nditer| nditer.next(ctx));
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
@ -158,7 +158,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
NDArrayType {
|
||||
@ -259,9 +259,9 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
|
||||
.unwrap();
|
||||
ndarray.store_itemsize(ctx, generator, itemsize);
|
||||
ndarray.store_itemsize(ctx, itemsize);
|
||||
|
||||
ndarray.store_ndims(ctx, generator, ndims);
|
||||
ndarray.store_ndims(ctx, ndims);
|
||||
|
||||
ndarray.create_shape(ctx, self.llvm_usize, ndims);
|
||||
ndarray.create_strides(ctx, self.llvm_usize, ndims);
|
||||
@ -307,7 +307,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
|
||||
.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
@ -342,7 +342,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
|
||||
.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
|
@ -52,7 +52,7 @@ impl<'ctx> TupleType<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Sanity check on object type.
|
||||
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {
|
||||
|
@ -418,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||
debug_assert_eq!(idx.get_type(), ctx.get_size_type());
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
|
@ -97,13 +97,8 @@ impl<'ctx> ListValue<'ctx> {
|
||||
}
|
||||
|
||||
/// Stores the `size` of this `list` into this instance.
|
||||
pub fn store_size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
size: IntValue<'ctx>,
|
||||
) {
|
||||
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
|
||||
pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) {
|
||||
debug_assert_eq!(size.get_type(), ctx.get_size_type());
|
||||
|
||||
self.len_field(ctx).set(ctx, self.value, size, self.name);
|
||||
}
|
||||
@ -213,7 +208,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||
debug_assert_eq!(idx.get_type(), ctx.get_size_type());
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
|
@ -112,7 +112,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
target_shape.base_ptr(ctx, generator),
|
||||
);
|
||||
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray);
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray);
|
||||
broadcast_ndarray
|
||||
}
|
||||
}
|
||||
@ -146,7 +146,7 @@ fn broadcast_shapes<'ctx, G, Shape>(
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx);
|
||||
|
||||
assert!(in_shape_entries
|
||||
@ -199,7 +199,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
) -> BroadcastAllResult<'ctx, G> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Infer the broadcast output ndims.
|
||||
let broadcast_ndims_int =
|
||||
|
@ -130,7 +130,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|
||||
|_, ctx| Ok(self.is_c_contiguous(ctx)),
|
||||
|_, ctx| {
|
||||
// This ndarray is contiguous.
|
||||
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
|
||||
@ -184,7 +184,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
// Copy shape and update strides
|
||||
let shape = carray.load_shape(ctx);
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||
ndarray.set_strides_contiguous(generator, ctx);
|
||||
ndarray.set_strides_contiguous(ctx);
|
||||
|
||||
// Share data
|
||||
let data = carray.load_data(ctx);
|
||||
|
@ -245,7 +245,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||
}
|
||||
RustNDIndex::Slice(in_rust_slice) => {
|
||||
let user_slice_ptr =
|
||||
SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx))
|
||||
SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type())
|
||||
.alloca_var(generator, ctx, None);
|
||||
in_rust_slice.write_to_slice(ctx, user_slice_ptr);
|
||||
|
||||
|
@ -35,7 +35,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
|
||||
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
|
||||
|
||||
// Deduce ndims of the result of matmul.
|
||||
@ -315,7 +315,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
let result_shape = result.shape();
|
||||
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
|
||||
|
||||
out_ndarray.copy_data_from(generator, ctx, result);
|
||||
out_ndarray.copy_data_from(ctx, result);
|
||||
out_ndarray
|
||||
}
|
||||
}
|
||||
|
@ -81,13 +81,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions `ndims` into this instance.
|
||||
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
ndims: IntValue<'ctx>,
|
||||
) {
|
||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) {
|
||||
debug_assert_eq!(ndims.get_type(), ctx.get_size_type());
|
||||
|
||||
let pndims = self.ptr_to_ndims(ctx);
|
||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||
@ -104,13 +99,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
}
|
||||
|
||||
/// Stores the size of each element `itemsize` into this instance.
|
||||
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
itemsize: IntValue<'ctx>,
|
||||
) {
|
||||
debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));
|
||||
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) {
|
||||
debug_assert_eq!(itemsize.get_type(), ctx.get_size_type());
|
||||
|
||||
self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name);
|
||||
}
|
||||
@ -205,12 +195,12 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) {
|
||||
let nbytes = self.nbytes(generator, ctx);
|
||||
let nbytes = self.nbytes(ctx);
|
||||
|
||||
let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
|
||||
self.store_data(ctx, data);
|
||||
|
||||
self.set_strides_contiguous(generator, ctx);
|
||||
self.set_strides_contiguous(ctx);
|
||||
}
|
||||
|
||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||
@ -284,52 +274,32 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
}
|
||||
|
||||
/// Get the `np.size()` of this ndarray.
|
||||
pub fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
|
||||
pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_size(ctx, *self)
|
||||
}
|
||||
|
||||
/// Get the `ndarray.nbytes` of this ndarray.
|
||||
pub fn nbytes<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
|
||||
pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self)
|
||||
}
|
||||
|
||||
/// Get the `len()` of this ndarray.
|
||||
pub fn len<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
|
||||
pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_len(ctx, *self)
|
||||
}
|
||||
|
||||
/// Check if this ndarray is C-contiguous.
|
||||
///
|
||||
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
|
||||
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
|
||||
pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self)
|
||||
}
|
||||
|
||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
||||
///
|
||||
/// Update the ndarray's strides to make the ndarray contiguous.
|
||||
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) {
|
||||
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
|
||||
pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) {
|
||||
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self);
|
||||
}
|
||||
|
||||
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and
|
||||
@ -347,7 +317,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
let shape = self.shape();
|
||||
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||
unsafe { clone.create_data(generator, ctx) };
|
||||
clone.copy_data_from(generator, ctx, *self);
|
||||
clone.copy_data_from(ctx, *self);
|
||||
clone
|
||||
}
|
||||
|
||||
@ -357,14 +327,9 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// do not matter. The copying order is determined by how their flattened views look.
|
||||
///
|
||||
/// Panics if the `dtype`s of ndarrays are different.
|
||||
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src: NDArrayValue<'ctx>,
|
||||
) {
|
||||
pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) {
|
||||
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
|
||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||
irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self);
|
||||
}
|
||||
|
||||
/// Fill the ndarray with a scalar.
|
||||
@ -468,7 +433,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
if self.is_unsized() {
|
||||
// NOTE: `np.size(self) == 0` here is never possible.
|
||||
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||
let zero = ctx.get_size_type().const_zero();
|
||||
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||
|
||||
Some(value)
|
||||
@ -756,9 +721,9 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0)
|
||||
irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -770,7 +735,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx);
|
||||
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx);
|
||||
|
||||
// Current implementation is transparent - The returned pointer type is
|
||||
// already cast into the expected type, allowing for immediately
|
||||
@ -834,7 +799,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||
indices: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into());
|
||||
assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into());
|
||||
|
||||
let indices = TypedArrayLikeAdapter::from(
|
||||
indices.as_slice_value(ctx, generator),
|
||||
@ -867,7 +832,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||
indices: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let indices_size = indices.size(ctx, generator);
|
||||
let nidx_leq_ndims = ctx
|
||||
|
@ -53,20 +53,16 @@ impl<'ctx> NDIterValue<'ctx> {
|
||||
/// If `ndarray` is unsized, this returns true only for the first iteration.
|
||||
/// If `ndarray` is 0-sized, this always returns false.
|
||||
#[must_use]
|
||||
pub fn has_element<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self)
|
||||
pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_nditer_has_element(ctx, *self)
|
||||
}
|
||||
|
||||
/// Go to the next element. If `has_element()` is false, then this has undefined behavior.
|
||||
///
|
||||
/// If `ndarray` is unsized, this can only be called once.
|
||||
/// If `ndarray` is 0-sized, this can never be called.
|
||||
pub fn next<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) {
|
||||
irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self);
|
||||
pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) {
|
||||
irrt::ndarray::call_nac3_nditer_next(ctx, *self);
|
||||
}
|
||||
|
||||
fn element_field(
|
||||
@ -167,10 +163,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
|generator, ctx| {
|
||||
Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self))
|
||||
},
|
||||
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)),
|
||||
|_, ctx, nditer| Ok(nditer.has_element(ctx)),
|
||||
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||
|generator, ctx, nditer| {
|
||||
nditer.next(generator, ctx);
|
||||
|_, ctx, nditer| {
|
||||
nditer.next(ctx);
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
|
||||
) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let zero = llvm_usize.const_zero();
|
||||
let one = llvm_usize.const_int(1, false);
|
||||
|
||||
|
@ -70,7 +70,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
|
||||
|
||||
// Resolve negative indices
|
||||
let size = self.size(generator, ctx);
|
||||
let size = self.size(ctx);
|
||||
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false);
|
||||
let dst_shape = dst_ndarray.shape();
|
||||
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
|
||||
@ -84,10 +84,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|
||||
|_, ctx| Ok(self.is_c_contiguous(ctx)),
|
||||
|generator, ctx| {
|
||||
// Reshape is possible without copying
|
||||
dst_ndarray.set_strides_contiguous(generator, ctx);
|
||||
dst_ndarray.set_strides_contiguous(ctx);
|
||||
dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator));
|
||||
|
||||
Ok(())
|
||||
@ -97,7 +97,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
unsafe {
|
||||
dst_ndarray.create_data(generator, ctx);
|
||||
}
|
||||
dst_ndarray.copy_data_from(generator, ctx, *self);
|
||||
dst_ndarray.copy_data_from(ctx, *self);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|
@ -1278,11 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
let size = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
ndarray.size(generator, ctx),
|
||||
ctx.ctx.i32_type(),
|
||||
"",
|
||||
)
|
||||
.build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "")
|
||||
.unwrap();
|
||||
Ok(Some(size.into()))
|
||||
}),
|
||||
|
Loading…
Reference in New Issue
Block a user