From 00800ba6ee994a359319cd35b56346461d888261 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 22:42:04 +0800 Subject: [PATCH] core/ndstrides: add ContiguousNDArray Currently this is used to interop with nalgebra. --- .../src/codegen/object/ndarray/contiguous.rs | 134 ++++++++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 1 + 2 files changed, 135 insertions(+) create mode 100644 nac3core/src/codegen/object/ndarray/contiguous.rs diff --git a/nac3core/src/codegen/object/ndarray/contiguous.rs b/nac3core/src/codegen/object/ndarray/contiguous.rs new file mode 100644 index 00000000..6f067679 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/contiguous.rs @@ -0,0 +1,134 @@ +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, +}; + +use super::NDArrayObject; + +/// Fields of [`ContiguousNDArray`] +pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, + pub data: F::Out>, +} + +/// An ndarray without strides and non-opaque `data` field in NAC3. +#[derive(Debug, Clone, Copy)] +pub struct ContiguousNDArray { + /// [`Model`] of the items. + pub item: M, +} + +impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray { + type Fields> = ContiguousNDArrayFields<'ctx, F, Item>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + data: traversal.add("data", Ptr(self.item)), + } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a [`ContiguousNDArray`] from the contents of this ndarray. + /// + /// This function may or may not be expensive depending on if this ndarray has contiguous data. + /// + /// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of + /// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there. + /// + /// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`] + /// will share memory with this ndarray. + /// + /// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and + /// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`] + /// if you don't care/cannot know the [`Model`] in advance. + pub fn make_contiguous_ndarray>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + item_model: Item, + ) -> Instance<'ctx, Ptr>>> { + // Sanity check on `self.dtype` and `item_model`. + let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); + item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap(); + + let cdarray_model = Struct(ContiguousNDArray { item: item_model }); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb"); + let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); + let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); + + // Allocate and setup the resulting [`ContiguousNDArray`]. + let result = cdarray_model.alloca(generator, ctx); + + // Set ndims and shape. + let ndims = self.ndims_llvm(generator, ctx.ctx); + result.set(ctx, |f| f.ndims, ndims); + + let shape = self.instance.get(generator, ctx, |f| f.shape); + result.set(ctx, |f| f.shape, shape); + + let is_contiguous = self.is_c_contiguous(generator, ctx); + ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap(); + + // Inserting into then_bb; This ndarray is contiguous. + ctx.builder.position_at_end(then_bb); + let data = self.instance.get(generator, ctx, |f| f.data); + let data = data.pointer_cast(generator, ctx, item_model); + result.set(ctx, |f| f.data, data); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`. + // `make_copy` produces an ndarray with contiguous `data`. + ctx.builder.position_at_end(else_bb); + let copied_ndarray = self.make_copy(generator, ctx); + let data = copied_ndarray.instance.get(generator, ctx, |f| f.data); + let data = data.pointer_cast(generator, ctx, item_model); + result.set(ctx, |f| f.data, data); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Reposition to end_bb for continuation + ctx.builder.position_at_end(end_bb); + + result + } + + /// Create an [`NDArrayObject`] from a [`ContiguousNDArray`]. + /// + /// The operation is super cheap. The newly created [`NDArrayObject`] will share the + /// same memory as the [`ContiguousNDArray`]. + /// + /// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite + /// the fact that the information should be contained within the [`ContiguousNDArray`]. + pub fn from_contiguous_ndarray>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + carray: Instance<'ctx, Ptr>>>, + dtype: Type, + ndims: u64, + ) -> Self { + // Sanity check on `dtype` and `contiguous_array`'s `Item` model. + let dtype_llvm = ctx.get_llvm_type(generator, dtype); + carray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap(); + + // TODO: Debug assert `ndims == carray.ndims` to catch bugs. + + // Allocate the resulting ndarray. + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); + + // Copy shape and update strides + let shape = carray.get(generator, ctx, |f| f.shape); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.set_strides_contiguous(generator, ctx); + + // Share data + let data = carray.get(generator, ctx, |f| f.data).pointer_cast(generator, ctx, Int(Byte)); + ndarray.instance.set(ctx, |f| f.data, data); + + ndarray + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 39f8884f..cd305790 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,5 +1,6 @@ pub mod array; pub mod broadcast; +pub mod contiguous; pub mod factory; pub mod indexing; pub mod map;