diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs new file mode 100644 index 0000000..0d63b22 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -0,0 +1,187 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; +use itertools::Itertools; + +use crate::codegen::{ + stmt::gen_for_callback, + types::{ + ndarray::{NDArrayType, NDIterType}, + ProxyType, + }, + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ArrayLikeValue, ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayType<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` + /// elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when + /// iterating through the input `ndarrays` after broadcasting. The output of `mapping` is the + /// result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing + /// ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[NDArrayValue<'ctx>], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result<>::Value, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = self.broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = + NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + .construct_uninitialized(generator, ctx, None); + result_ndarray.copy_shape_from_array( + generator, + ctx, + broadcast_result.shape.base_ptr(ctx, generator), + ); + unsafe { + result_ndarray.create_data(generator, ctx); + } + result_ndarray + } + + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| { + NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) + }) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_element()`. + // `in_nditers`' `has_element()`s should return the same value. + Ok(out_nditer.has_element(generator, ctx)) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + let in_scalars = + in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a + /// scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars + /// or ndarrays or a mix of them as their inputs and produces either an ndarray with broadcast, + /// or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following + /// behaviors: + /// + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is + /// converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> + /// ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a + /// [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be + /// 'as-ndarray'-ed into ndarrays, then all inputs (now all ndarrays) will be passed to + /// [`NDArrayValue::broadcasting_starmap`] and **create** a new ndarray with dtype `ret_dtype`. + pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: BasicTypeEnum<'ctx>, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = + inputs.iter().map(BasicValueEnum::<'ctx>::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().copied().collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(value)) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ret_dtype, + &inputs.iter().map(NDArrayValue::get_type).collect_vec(), + ) + .broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 316d0f3..43712be 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -30,6 +30,7 @@ mod broadcast; mod contiguous; pub mod factory; mod indexing; +mod map; mod nditer; /// Proxy type for a `ndarray` type in LLVM. diff --git a/nac3core/src/codegen/values/ndarray/map.rs b/nac3core/src/codegen/values/ndarray/map.rs new file mode 100644 index 0000000..72d1bf9 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/map.rs @@ -0,0 +1,69 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; + +use crate::codegen::{ + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Map through this ndarray with an elementwise function. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + self.get_type().broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a + /// [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new + /// ndarray of the results will be created and returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: BasicTypeEnum<'ctx>, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 6c8c9aa..89f88e7 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -31,6 +31,7 @@ pub use nditer::*; mod broadcast; mod contiguous; mod indexing; +mod map; mod nditer; pub mod shape; mod view; @@ -540,6 +541,26 @@ impl<'ctx> NDArrayValue<'ctx> { ScalarOrNDArray::NDArray(*self) } } + + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. + pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_shape: impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) { + let ndarray_shape = self.shape(); + let output_shape = out_shape; + + irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + &ndarray_shape, + &output_shape, + ); + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { @@ -1081,3 +1102,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } } + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a +/// function will create a new ndarray and store the result in it. +#[derive(Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: BasicTypeEnum<'ctx> }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayValue<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +}