This commit is contained in:
David Mak 2024-11-21 14:25:05 +08:00
parent ebeb4f6dca
commit acd976289f
9 changed files with 605 additions and 186 deletions

View File

@ -721,7 +721,9 @@ fn format_rpc_ret<'ctx>(
); );
} }
ndarray.create_data(generator, ctx, llvm_elem_ty, num_elements); unsafe {
ndarray.create_data(generator, ctx, num_elements);
}
let ndarray_data = ndarray.data().base_ptr(ctx, generator); let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 = let ndarray_data_i8 =

View File

@ -32,7 +32,7 @@ use super::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var, gen_var,
}, },
types::{ListType, ProxyType}, types::{ListType, NDArrayType, ProxyType},
values::{ values::{
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
TypedArrayLikeAccessor, UntypedArrayLikeAccessor, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
@ -43,7 +43,7 @@ use crate::{
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::unpack_ndarray_var_tys,
DefinitionId, TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -2595,14 +2595,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
_ => 1, _ => 1,
}; };
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
@ -2797,26 +2789,15 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
let num_dims = v.load_ndims(ctx);
let num_dims = ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap();
// Create a new array, remove the top dimension from the dimension-size-list, and copy the // Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over // elements over
let subscripted_ndarray = let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t)
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; .construct_uninitialized(generator, ctx, num_dims, None);
let ndarray = NDArrayValue::from_pointer_value(
subscripted_ndarray,
llvm_ndarray_data_t,
None,
llvm_usize,
None,
);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
@ -2858,7 +2839,9 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
.builder .builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap(); .unwrap();
ndarray.create_data(generator, ctx, llvm_ndarray_data_t, ndarray_num_elems); unsafe {
ndarray.create_data(generator, ctx, ndarray_num_elems);
}
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic( call_memcpy_generic(
@ -3604,3 +3587,90 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
_ => unimplemented!(), _ => unimplemented!(),
})) }))
} }
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
pub fn create_fn_and_call<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
fn_name: &str,
ret_type: Option<BasicTypeEnum<'ctx>>,
(params, is_var_args): (&[BasicTypeEnum<'ctx>], bool),
args: &[BasicValueEnum<'ctx>],
call_value_name: Option<&str>,
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
let intrinsic_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
let params = params.iter().copied().map(BasicTypeEnum::into).collect_vec();
let fn_type = if let Some(ret_type) = ret_type {
ret_type.fn_type(params.as_slice(), is_var_args)
} else {
ctx.ctx.void_type().fn_type(params.as_slice(), is_var_args)
};
ctx.module.add_function(fn_name, fn_type, None)
});
if let Some(configure) = configure {
configure(&intrinsic_fn);
}
let args = args.iter().copied().map(BasicValueEnum::into).collect_vec();
ctx.builder
.build_call(intrinsic_fn, args.as_slice(), call_value_name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(Either::left)
.unwrap()
}
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
///
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
/// parameters and arguments to be specified as tuples to better indicate the expected type and
/// actual value of each parameter-argument pair of the call.
pub fn create_and_call_function<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
fn_name: &str,
ret_type: Option<BasicTypeEnum<'ctx>>,
params: &[(BasicTypeEnum<'ctx>, BasicValueEnum<'ctx>)],
value_name: Option<&str>,
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
let param_tys = params.iter().map(|(ty, _)| ty).copied().map(BasicTypeEnum::into).collect_vec();
let arg_values =
params.iter().map(|(_, value)| value).copied().map(BasicValueEnum::into).collect_vec();
create_fn_and_call(
ctx,
fn_name,
ret_type,
(param_tys.as_slice(), false),
arg_values.as_slice(),
value_name,
configure,
)
}
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
///
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
/// only arguments to be specified and performs inference for the parameter types using
/// [`BasicValueEnum::get_type`].
pub fn infer_and_call_function<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
fn_name: &str,
ret_type: Option<BasicTypeEnum<'ctx>>,
args: &[BasicValueEnum<'ctx>],
value_name: Option<&str>,
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
let param_tys = args.iter().map(BasicValueEnum::get_type).collect_vec();
create_fn_and_call(
ctx,
fn_name,
ret_type,
(param_tys.as_slice(), false),
args,
value_name,
configure,
)
}

View File

@ -59,6 +59,27 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
irrt_mod irrt_mod
} }
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
///
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
#[must_use]
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// NOTE: the output value of the end index of this function should be compared ***inclusively***,
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
/// NO numeric slice in python. /// NO numeric slice in python.

View File

@ -1,134 +1,258 @@
use crate::codegen::{CodeGenContext, CodeGenerator}; use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`. use crate::codegen::{
/// expr::create_and_call_function,
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. irrt::get_usize_dependent_function_name,
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. types::NDArrayType,
#[must_use] values::{NDArrayValue, ProxyValue},
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>( CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: &str, ndims: IntValue<'ctx>,
) -> String { shape: PointerValue<'ctx>,
let mut name = name.to_owned(); ) {
match generator.get_size_type(ctx.ctx).get_bit_width() { let llvm_usize = generator.get_size_type(ctx.ctx);
32 => {} let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
64 => name.push_str("64"),
bit_width => { let name = get_usize_dependent_function_name(
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") generator,
} ctx,
} "__nac3_ndarray_util_assert_shape_no_negative",
name );
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
None,
None,
);
} }
// pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G, generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
// ndims: Instance<'ctx, Int<SizeT>>, ndarray_ndims: IntValue<'ctx>,
// shape: Instance<'ctx, Ptr<Int<SizeT>>>, ndarray_shape: PointerValue<'ctx>,
// ) { output_ndims: IntValue<'ctx>,
// let name = get_usize_dependent_function_name( output_shape: IntValue<'ctx>,
// generator, ) {
// ctx, let llvm_usize = generator.get_size_type(ctx.ctx);
// "__nac3_ndarray_util_assert_shape_no_negative", let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
// );
// FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void(); let name = get_usize_dependent_function_name(
// } generator,
// ctx,
// pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( "__nac3_ndarray_util_assert_output_shape_same",
// generator: &mut G, );
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray_ndims: Instance<'ctx, Int<SizeT>>, create_and_call_function(
// ndarray_shape: Instance<'ctx, Ptr<Int<SizeT>>>, ctx,
// output_ndims: Instance<'ctx, Int<SizeT>>, &name,
// output_shape: Instance<'ctx, Ptr<Int<SizeT>>>, Some(llvm_usize.into()),
// ) { &[
// let name = get_usize_dependent_function_name( (llvm_usize.into(), ndarray_ndims.into()),
// generator, (llvm_pusize.into(), ndarray_shape.into()),
// ctx, (llvm_usize.into(), output_ndims.into()),
// "__nac3_ndarray_util_assert_output_shape_same", (llvm_pusize.into(), output_shape.into()),
// ); ],
// FnCall::builder(generator, ctx, &name) None,
// .arg(ndarray_ndims) None,
// .arg(ndarray_shape) );
// .arg(output_ndims) }
// .arg(output_shape)
// .returning_void(); pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
// } generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( ndarray: NDArrayValue<'ctx>,
// generator: &mut G, ) -> IntValue<'ctx> {
// ctx: &mut CodeGenContext<'ctx, '_>, let llvm_usize = generator.get_size_type(ctx.ctx);
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
// ) -> Instance<'ctx, Int<SizeT>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
// } create_and_call_function(
// ctx,
// pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( &name,
// generator: &mut G, Some(llvm_usize.into()),
// ctx: &mut CodeGenContext<'ctx, '_>, &[(llvm_ndarray.into(), ndarray.as_base_value().into())],
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, Some("size"),
// ) -> Instance<'ctx, Int<SizeT>> { None,
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); )
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes") .map(BasicValueEnum::into_int_value)
// } .unwrap()
// }
// pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G, pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
// ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ctx: &mut CodeGenContext<'ctx, '_>,
// ) -> Instance<'ctx, Int<SizeT>> { ndarray: NDArrayValue<'ctx>,
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); ) -> IntValue<'ctx> {
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len") let llvm_usize = generator.get_size_type(ctx.ctx);
// } let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
//
// pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>, create_and_call_function(
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ctx,
// ) -> Instance<'ctx, Int<Bool>> { &name,
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); Some(llvm_usize.into()),
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous") &[(llvm_ndarray.into(), ndarray.as_base_value().into())],
// } Some("nbytes"),
// None,
// pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( )
// generator: &mut G, .map(BasicValueEnum::into_int_value)
// ctx: &mut CodeGenContext<'ctx, '_>, .unwrap()
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, }
// index: Instance<'ctx, Int<SizeT>>,
// ) -> Instance<'ctx, Ptr<Int<Byte>>> { pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); generator: &mut G,
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement") ctx: &mut CodeGenContext<'ctx, '_>,
// } ndarray: NDArrayValue<'ctx>,
// ) -> IntValue<'ctx> {
// pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx);
// generator: &mut G, let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
// indices: Instance<'ctx, Ptr<Int<SizeT>>>,
// ) -> Instance<'ctx, Ptr<Int<Byte>>> { create_and_call_function(
// let name = ctx,
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); &name,
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement") Some(llvm_usize.into()),
// } &[(llvm_ndarray.into(), ndarray.as_base_value().into())],
// Some("len"),
// pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( None,
// generator: &mut G, )
// ctx: &mut CodeGenContext<'ctx, '_>, .map(BasicValueEnum::into_int_value)
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, .unwrap()
// ) { }
// let name =
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void(); generator: &mut G,
// } ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: NDArrayValue<'ctx>,
// pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( ) -> IntValue<'ctx> {
// generator: &mut G, let llvm_i1 = ctx.ctx.bool_type();
// ctx: &mut CodeGenContext<'ctx, '_>, let llvm_usize = generator.get_size_type(ctx.ctx);
// src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
// dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) { let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
// FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); create_and_call_function(
// } ctx,
&name,
Some(llvm_i1.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
Some("is_c_contiguous"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
index: IntValue<'ctx>,
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function(
ctx,
&name,
Some(llvm_pi8.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
Some("pelement"),
None,
)
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
create_and_call_function(
ctx,
&name,
Some(llvm_pi8.into()),
&[
(llvm_ndarray.into(), ndarray.as_base_value().into()),
(llvm_pusize.into(), indices.into()),
],
Some("pelement"),
None,
)
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
create_and_call_function(
ctx,
&name,
None,
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
None,
None,
);
}
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
create_and_call_function(
ctx,
&name,
None,
&[
(llvm_ndarray.into(), src_ndarray.as_base_value().into()),
(llvm_ndarray.into(), dst_ndarray.as_base_value().into()),
],
None,
None,
);
}

View File

@ -201,6 +201,52 @@ pub fn call_memcpy_generic<'ctx>(
call_memcpy(ctx, dest, src, len, is_volatile); call_memcpy(ctx, dest, src, len, is_volatile);
} }
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
/// Moreover, `len` now refers to the number of elements (rather than bytes) to copy.
pub fn call_memcpy_generic_array<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let len = ctx.builder.build_int_cast(len, llvm_sizeof_expr_t, "").unwrap();
let len = ctx.builder.build_int_mul(
len,
src_elem_t.size_of().unwrap(),
""
).unwrap();
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function) /// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
/// ///
/// Arguments: /// Arguments:

View File

@ -41,6 +41,7 @@ use crate::{
}; };
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
#[deprecated = "Use NDArrayType::construct_uninitialized instead."]
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -84,6 +85,7 @@ where
) -> Result<IntValue<'ctx>, String>, ) -> Result<IntValue<'ctx>, String>,
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
// Assert that all dimensions are non-negative // Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -123,10 +125,10 @@ where
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
let num_dims = shape_len_fn(generator, ctx, shape)?; let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
.construct_uninitialized(generator, ctx, num_dims, None);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
@ -215,7 +217,9 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
&ndarray.shape().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
ndarray.create_data(generator, ctx, llvm_ndarray_data_t, ndarray_num_elems); unsafe {
ndarray.create_data(generator, ctx, ndarray_num_elems);
}
ndarray ndarray
} }
@ -1262,6 +1266,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray = if slices.is_empty() { let ndarray = if slices.is_empty() {
create_ndarray_dyn_shape( create_ndarray_dyn_shape(
@ -1275,8 +1280,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
}, },
)? )?
} else { } else {
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); .construct_uninitialized(generator, ctx, this.load_ndims(ctx), None);
let ndims = this.load_ndims(ctx); let ndims = this.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndims); ndarray.create_shape(ctx, llvm_usize, ndims);

View File

@ -82,7 +82,7 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType // TODO: Move this as a member of this Struct
#[must_use] #[must_use]
fn fields( fn fields(
ctx: impl AsContextRef<'ctx>, ctx: impl AsContextRef<'ctx>,
@ -103,7 +103,7 @@ impl<'ctx> NDArrayType<'ctx> {
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use] #[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { pub fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* } // struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
// //
// * data : Pointer to an array containing the array data // * data : Pointer to an array containing the array data
@ -189,22 +189,20 @@ impl<'ctx> NDArrayType<'ctx> {
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndims: u64, // ndims: u64,
ndims: IntValue<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.new_value(generator, ctx, name); let ndarray = self.new_value(generator, ctx, name);
let itemsize = ctx let itemsize =
.builder ctx.builder.build_int_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "").unwrap();
.build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap();
ndarray.store_itemsize(ctx, generator, itemsize); ndarray.store_itemsize(ctx, generator, itemsize);
let ndims_val = self.llvm_usize.const_int(ndims, false); ndarray.store_ndims(ctx, generator, ndims);
ndarray.store_ndims(ctx, generator, ndims_val);
ndarray.create_shape(ctx, self.llvm_usize, ndims_val); ndarray.create_shape(ctx, self.llvm_usize, ndims);
ndarray.create_strides(ctx, self.llvm_usize, ndims_val); ndarray.create_strides(ctx, self.llvm_usize, ndims);
ndarray ndarray
} }
@ -220,7 +218,14 @@ impl<'ctx> NDArrayType<'ctx> {
shape: &[u64], shape: &[u64],
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray = self.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(shape.len() as u64, false),
name,
);
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();
@ -250,7 +255,14 @@ impl<'ctx> NDArrayType<'ctx> {
shape: &[IntValue<'ctx>], shape: &[IntValue<'ctx>],
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray = self.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(shape.len() as u64, false),
name,
);
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();

View File

@ -145,7 +145,7 @@ where
} }
/// Sets the value of this field for a given `obj`. /// Sets the value of this field for a given `obj`.
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) { pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value); obj.set_field_at_index(self.index, value);
} }

View File

@ -10,7 +10,7 @@ use super::{
}; };
use crate::codegen::{ use crate::codegen::{
irrt, irrt,
llvm_intrinsics::call_int_umin, llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
type_aligned_alloca, type_aligned_alloca,
types::{structure::StructField, NDArrayType}, types::{structure::StructField, NDArrayType},
@ -79,7 +79,9 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
fn itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { fn itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx, self.llvm_usize).itemsize self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.itemsize
} }
/// Stores the size of each element `itemsize` into this instance. /// Stores the size of each element `itemsize` into this instance.
@ -179,19 +181,29 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Convenience method for creating a new array storing data elements with the given element /// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`. /// type `elem_ty` and `size`.
pub fn create_data<G: CodeGenerator + ?Sized>( ///
/// The data buffer will be allocated on the stack, and is considered to be owned by this ndarray instance.
///
/// # Safety
///
/// `shape` and `itemsize` of the ndarray must be initialized.
pub unsafe fn create_data<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>, size: IntValue<'ctx>,
) { ) {
// let itemsize =
// ctx.builder.build_int_cast(self.load_itemsize(ctx), size.get_type(), "").unwrap();
let itemsize = let itemsize =
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap(); ctx.builder.build_int_cast(self.dtype.size_of().unwrap(), size.get_type(), "").unwrap();
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap(); let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
// let nbytes = self.nbytes(generator, ctx);
let data = type_aligned_alloca(generator, ctx, elem_ty, nbytes, None); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
self.store_data(ctx, data); self.store_data(ctx, data);
// self.set_strides_contiguous(generator, ctx);
} }
/// Returns a proxy object to the field storing the data of this `NDArray`. /// Returns a proxy object to the field storing the data of this `NDArray`.
@ -199,6 +211,133 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self) NDArrayDataProxy(self)
} }
/// Copy shape dimensions from an array.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: PointerValue<'ctx>,
) {
let num_items = self.load_ndims(ctx);
call_memcpy_generic_array(
ctx,
self.shape().base_ptr(ctx, generator),
shape,
num_items,
ctx.ctx.bool_type().const_zero(),
);
}
/// Copy shape dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
strides: PointerValue<'ctx>,
) {
let num_items = self.load_ndims(ctx);
call_memcpy_generic_array(
ctx,
self.strides().base_ptr(ctx, generator),
strides,
num_items,
ctx.ctx.bool_type().const_zero(),
);
}
/// Copy strides dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
}
/// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
/// do not matter. The copying order is determined by how their flattened views look.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
) {
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
}
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {