From 3efae534f7f88e02389c1fa9856ed1603f94be70 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 17:03:23 +0800 Subject: [PATCH] core/ndstrides: add NDArrayOut, broadcast_map and map --- nac3core/src/codegen/model/any.rs | 2 +- nac3core/src/codegen/model/array.rs | 51 ----- nac3core/src/codegen/model/core.rs | 30 +++ nac3core/src/codegen/model/structure.rs | 2 +- nac3core/src/codegen/object/ndarray/map.rs | 220 +++++++++++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 51 +++++ 6 files changed, 303 insertions(+), 53 deletions(-) create mode 100644 nac3core/src/codegen/object/ndarray/map.rs diff --git a/nac3core/src/codegen/model/any.rs b/nac3core/src/codegen/model/any.rs index d843f764..84a444d6 100644 --- a/nac3core/src/codegen/model/any.rs +++ b/nac3core/src/codegen/model/any.rs @@ -10,7 +10,7 @@ use super::*; /// A [`Model`] of any [`BasicTypeEnum`]. /// -/// Use this when you don't need/cannot have any static types to escape from the [`Model`] abstraction. +/// Use this when you cannot know the type beforehand or cannot be abstracted with [`Model`]. #[derive(Debug, Clone, Copy)] pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>); diff --git a/nac3core/src/codegen/model/array.rs b/nac3core/src/codegen/model/array.rs index bcf7ab66..829c27fe 100644 --- a/nac3core/src/codegen/model/array.rs +++ b/nac3core/src/codegen/model/array.rs @@ -70,54 +70,3 @@ impl<'ctx, Element: Model<'ctx>> Instance<'ctx, Ptr>> { Ptr(self.model.0.item).check_value(generator, ctx.ctx, ptr).unwrap() } } - -/// Like [`ArrayModel`] but length is strongly-typed. -#[derive(Debug, Clone, Copy, Default)] -pub struct NArrayModel(pub Element); -pub type NArray<'ctx, const LEN: u32, Element> = Instance<'ctx, NArrayModel>; - -impl<'ctx, const LEN: u32, Element: Model<'ctx>> NArrayModel { - /// Forget the `LEN` constant generic and get an [`ArrayModel`] with the same length. 
- pub fn forget_len(&self) -> Array { - Array { item: self.0, len: LEN } - } -} - -impl<'ctx, const LEN: u32, Element: Model<'ctx>> Model<'ctx> for NArrayModel { - type Value = ArrayValue<'ctx>; - type Type = ArrayType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { - // Convenient implementation - self.forget_len().get_type(generator, ctx) - } - - fn check_type, G: CodeGenerator + ?Sized>( - &self, - generator: &mut G, - ctx: &'ctx Context, - ty: T, - ) -> Result<(), ModelError> { - // Convenient implementation - self.forget_len().check_type(generator, ctx, ty) - } -} - -impl<'ctx, const LEN: u32, Element: Model<'ctx>> Instance<'ctx, Ptr>> { - /// Get the pointer to the `i`-th (0-based) array element. - pub fn at_const( - &self, - generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, - i: u32, - name: &str, - ) -> Instance<'ctx, Ptr> { - assert!(i < LEN); - - let zero = ctx.ctx.i32_type().const_zero(); - let i = ctx.ctx.i32_type().const_int(u64::from(i), false); - let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() }; - - Ptr(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap() - } -} diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 3ca1ea39..f9950e0a 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -1,6 +1,7 @@ use std::fmt; use inkwell::{context::Context, types::*, values::*}; +use itertools::Itertools; use super::*; use crate::codegen::{CodeGenContext, CodeGenerator}; @@ -110,6 +111,35 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy { let p = generator.gen_array_var_alloc(ctx, ty, len, name)?; Ok(Ptr(*self).believe_value(PointerValue::from(p))) } + + fn const_array( + &self, + generator: &mut G, + ctx: &'ctx Context, + values: &[Instance<'ctx, Self>], + ) -> Instance<'ctx, Array> { + macro_rules! 
make { + ($t:expr, $into_value:expr) => { + $t.const_array( + &values + .iter() + .map(|x| $into_value(x.value.as_basic_value_enum())) + .collect_vec(), + ) + }; + } + + let value = match self.get_type(generator, ctx).as_basic_type_enum() { + BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value), + BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value), + BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value), + BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value), + BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value), + BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value), + }; + + Array { len: values.len() as u32, item: *self }.check_value(generator, ctx, value).unwrap() + } } #[derive(Debug, Clone, Copy)] diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs index e82ed01d..19c511f7 100644 --- a/nac3core/src/codegen/model/structure.rs +++ b/nac3core/src/codegen/model/structure.rs @@ -134,7 +134,7 @@ impl<'ctx, S: StructKind<'ctx>> Struct { /// Create a constant struct value. /// /// This function also validates `fields` and panic when there is something wrong. 
- fn const_struct( + pub fn const_struct( &self, generator: &mut G, ctx: &'ctx Context, diff --git a/nac3core/src/codegen/object/ndarray/map.rs b/nac3core/src/codegen/object/ndarray/map.rs new file mode 100644 index 00000000..4fcefe23 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/map.rs @@ -0,0 +1,220 @@ +use inkwell::values::BasicValueEnum; +use itertools::Itertools; + +use crate::{ + codegen::{ + object::ndarray::{AnyObject, NDArrayObject}, + stmt::gen_for_callback, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through + /// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation. + /// + /// `out` specifies whether the result should be stored in a new ndarray or written to an existing ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[Self], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. 
+ let result_ndarray = + NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims); + result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); + result_ndarray.create_data(generator, ctx); + result_ndarray + } + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out( + generator, + ctx, + broadcast_result.ndims, + broadcast_result.shape, + ); + result_ndarray + } + }; + + // Map element-wise and store results into `out_ndarray`. + let nditer = NDIterHandle::new(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray)) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_next()`. + // `in_nditers`' `has_next()`s should return the same value. + Ok(out_nditer.has_next(generator, ctx).value) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + + let in_scalars = in_nditers + .iter() + .map(|nditer| nditer.get_scalar(generator, ctx).value) + .collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } + + /// Map through this ndarray with an elementwise function. 
+ pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + NDArrayObject::broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where each input may be either an ndarray or a scalar. + /// + /// This function is very helpful when implementing NumPy functions that take scalars, ndarrays, or a mix of them + /// as their inputs and produce either a broadcasted ndarray, or a scalar if all of their inputs are scalars. + /// + /// For example, this function can be used to implement `np.add`, which has the following behaviors: + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays, + /// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcast_starmap`] to **create** a new ndarray + /// with dtype `ret_dtype`. 
+ pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: Type, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = inputs.iter().map(AnyObject::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value })) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } + + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and + /// returned as a [`ScalarOrNDArray::NDArray`]. 
+ pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: Type, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index d26e204f..72848efb 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -2,6 +2,7 @@ pub mod array; pub mod broadcast; pub mod factory; pub mod indexing; +pub mod map; pub mod nditer; pub mod shape_util; pub mod view; @@ -20,6 +21,7 @@ use crate::{ call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + call_nac3_ndarray_util_assert_output_shape_same, }, model::*, CodeGenContext, CodeGenerator, @@ -506,6 +508,31 @@ impl<'ctx> NDArrayObject<'ctx> { ndarray.instance.set(ctx, |f| f.data, data); ndarray } + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. 
+ pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_ndims: u64, + out_shape: Instance<'ctx, Ptr>>, + ) { + let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx); + let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape); + + let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims); + let output_shape = out_shape; + + call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + ndarray_ndims, + ndarray_shape, + output_ndims, + output_shape, + ); + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. @@ -588,3 +615,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } } + +/// A helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy have an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will +/// create a new ndarray and store the result in it. +#[derive(Debug, Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function to create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: Type }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayObject<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +}