From 3b87bd36f3d86156d6cc040acf29945dc318158a Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 14 Jul 2024 14:17:51 +0800 Subject: [PATCH] core: irrt ndarray setup --- nac3core/irrt/irrt/ndarray/ndarray.hpp | 134 ++++++++ nac3core/irrt/irrt/ndarray/ndarray_util.hpp | 107 ++++++ nac3core/irrt/irrt_everything.hpp | 3 +- nac3core/src/codegen/irrt/classes.rs | 35 +- nac3core/src/codegen/irrt/mod.rs | 1 + nac3core/src/codegen/irrt/new.rs | 3 +- nac3core/src/codegen/irrt/numpy.rs | 354 ++++++++++++++++++++ nac3core/src/codegen/optics.rs | 21 +- 8 files changed, 649 insertions(+), 9 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/ndarray.hpp create mode 100644 nac3core/irrt/irrt/ndarray/ndarray_util.hpp create mode 100644 nac3core/src/codegen/irrt/numpy.rs diff --git a/nac3core/irrt/irrt/ndarray/ndarray.hpp b/nac3core/irrt/irrt/ndarray/ndarray.hpp new file mode 100644 index 00000000..29c2e739 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/ndarray.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include +#include + +namespace { + +// The NDArray object. `SizeT` is the *signed* size type of this ndarray. +// +// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT +// +// Some resources you might find helpful: +// - The official numpy implementations: +// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst +// - On strides (about reshaping, slicing, C-contagiousness, etc) +// - https://ajcr.net/stride-guide-part-1/. +// - https://ajcr.net/stride-guide-part-2/. +// - https://ajcr.net/stride-guide-part-3/. +template +struct NDArray { + // The underlying data this `ndarray` is pointing to. + // + // NOTE: Formally this should be of type `void *`, but clang + // translates `void *` to `i8 *` when run with `-S -emit-llvm`, + // so we will put `uint8_t *` here for clarity. + // + // This pointer should point to the first element of the ndarray directly + uint8_t *data; + + // The number of bytes of a single element in `data`. + // + // The `SizeT` is treated as `unsigned`. + SizeT itemsize; + + // The number of dimensions of this shape. + // + // The `SizeT` is treated as `unsigned`. + SizeT ndims; + + // Array shape, with length equal to `ndims`. + // + // The `SizeT` is treated as `unsigned`. + // + // NOTE: `shape` can contain 0. + // (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`) + SizeT *shape; + + // Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`. + // + // The `SizeT` is treated as `signed`. + // + // NOTE: `strides` can have negative numbers. + // (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`) + SizeT *strides; + + // Calculate the size/# of elements of an `ndarray`. + // This function corresponds to `np.size()` or `ndarray.size` + SizeT size() { + return ndarray_util::calc_size_from_shape(ndims, shape); + } + + // Calculate the number of bytes of its content of an `ndarray` *in its view*. + // This function corresponds to `ndarray.nbytes` + SizeT nbytes() { + return this->size() * itemsize; + } + + // Set the strides of the ndarray with `ndarray_util::set_strides_by_shape` + void set_strides_by_shape() { + ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape); + } + + uint8_t* get_pelement_by_indices(const SizeT *indices) { + uint8_t* element = data; + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) + element += indices[dim_i] * strides[dim_i]; + return element; + } + + uint8_t* get_nth_pelement(SizeT nth) { + SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims); + ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth); + return get_pelement_by_indices(indices); + } + + // Get the pointer to the nth element of the ndarray as if it were flattened. + uint8_t* checked_get_nth_pelement(ErrorContext* errctx, SizeT nth) { + SizeT arr_size = this->size(); + if (!(0 <= nth && nth < arr_size)) { + errctx->set_error( + errctx->error_ids->index_error, + "index {0} is out of bounds, valid range is {1} <= index < {2}", + nth, 0, arr_size + ); + return 0; + } + return get_nth_pelement(nth); + } +}; +} + +extern "C" { +uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return ndarray->size(); +} + +uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return ndarray->size(); +} + +uint32_t __nac3_ndarray_nbytes(NDArray* ndarray) { + return ndarray->nbytes(); +} + +uint64_t __nac3_ndarray_nbytes64(NDArray* ndarray) { + return ndarray->nbytes(); +} + +void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) { + ndarray_util::assert_shape_no_negative(errctx, ndims, shape); +} + +void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) { + ndarray_util::assert_shape_no_negative(errctx, ndims, shape); +} + +void __nac3_ndarray_set_strides_by_shape(NDArray* ndarray) { + ndarray->set_strides_by_shape(); +} + +void __nac3_ndarray_set_strides_by_shape64(NDArray* ndarray) { + ndarray->set_strides_by_shape(); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/ndarray_util.hpp b/nac3core/irrt/irrt/ndarray/ndarray_util.hpp new file mode 100644 index 00000000..e99804c8 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/ndarray_util.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include + +namespace { +namespace ndarray_util { + +// Throw an error if there is an axis with negative dimension +template +void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + if (shape[axis] < 0) { + errctx->set_error( + errctx->error_ids->value_error, + "negative dimensions are not allowed; axis {0} has dimension {1}", + axis, shape[axis] + ); + return; + } + } +} + +// Compute the size/# of elements of an ndarray given its shape +template +SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis]; + return size; +} + +// Compute the strides of an ndarray given an ndarray `shape` +// and assuming that the ndarray is *fully C-contagious*. +// +// You might want to read up on https://ajcr.net/stride-guide-part-1/. +template +void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndims; i++) { + int axis = ndims - i - 1; + dst_strides[axis] = stride_product * itemsize; + stride_product *= shape[axis]; + } +} + +template +void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) { + for (int32_t i = 0; i < ndims; i++) { + int32_t axis = ndims - i - 1; + int32_t dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} + +template +bool can_broadcast_shape_to( + const SizeT target_ndims, + const SizeT *target_shape, + const SizeT src_ndims, + const SizeT *src_shape +) { + /* + // See https://numpy.org/doc/stable/user/basics.broadcasting.html + + This function handles this example: + ``` + Image (3d array): 256 x 256 x 3 + Scale (1d array): 3 + Result (3d array): 256 x 256 x 3 + ``` + + Other interesting examples to consider: + - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true` + - `can_broadcast_shape_to([3], [3, 1]) == false` + - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true` + + In cases when the shapes contain zero(es): + - `can_broadcast_shape_to([0], [1]) == true` + - `can_broadcast_shape_to([0], [2]) == false` + - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true` + - `can_broadcast_shape_to([4, 3], [0, 3]) == false` + - `can_broadcast_shape_to([4, 3], [0, 0]) == false` + */ + + // This is essentially doing the following in Python: + // `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)` + for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) { + SizeT target_axis = target_ndims - i - 1; + SizeT src_axis = src_ndims - i - 1; + + bool target_dim_exists = target_axis >= 0; + bool src_dim_exists = src_axis >= 0; + + SizeT target_dim = target_dim_exists ? target_shape[target_axis] : 1; + SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1; + + bool ok = src_dim == 1 || target_dim == src_dim; + if (!ok) return false; + } + + return true; +} +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index e7919124..696a7dd0 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -3,4 +3,5 @@ #include #include #include -#include \ No newline at end of file +#include +#include \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/classes.rs b/nac3core/src/codegen/irrt/classes.rs index 3b10e6f2..fe2f5f39 100644 --- a/nac3core/src/codegen/irrt/classes.rs +++ b/nac3core/src/codegen/irrt/classes.rs @@ -1,6 +1,11 @@ -use inkwell::types::{BasicTypeEnum, IntType}; +use inkwell::types::IntType; -use crate::codegen::optics::{AddressLens, FieldBuilder, GepGetter, IntLens, StructureOptic}; +use crate::codegen::{ + optics::{ + Address, AddressLens, ArraySlice, FieldBuilder, GepGetter, IntLens, Optic, StructureOptic, + }, + CodeGenContext, +}; #[derive(Debug, Clone)] pub struct StrLens<'ctx> { @@ -36,13 +41,18 @@ pub struct NpArrayFields<'ctx> { pub strides: GepGetter>>, } +// Note: NpArrayLens's ElementOptic is purely for type-safety and type-guidances +// The underlying LLVM ndarray doesn't care, it only holds an opaque (uint8_t*) pointer to the elements. #[derive(Debug, Clone, Copy)] -pub struct NpArrayLens<'ctx> { +pub struct NpArrayLens<'ctx, ElementOptic> { pub size_type: IntType<'ctx>, - pub elem_type: BasicTypeEnum<'ctx>, + pub element_optic: ElementOptic, } -impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { +// NDArray is *frequently* used, so here is a type alias +pub type NpArray<'ctx, ElementOptic> = Address<'ctx, NpArrayLens<'ctx, ElementOptic>>; + +impl<'ctx, ElementOptic: Optic<'ctx>> StructureOptic<'ctx> for NpArrayLens<'ctx, ElementOptic> { type Fields = NpArrayFields<'ctx>; fn struct_name(&self) -> &'static str { @@ -63,6 +73,21 @@ impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { } } +// Other convenient utilities for NpArray +impl<'ctx, ElementOptic: Optic<'ctx>> NpArray<'ctx, ElementOptic> { + pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { + let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); + let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape"); + ArraySlice { num_elements: ndims, base: shape_base_ptr } + } + + pub fn strides_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { + let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); + let strides_base_ptr = self.focus(ctx, |fields| &fields.strides).load(ctx, "strides"); + ArraySlice { num_elements: ndims, base: strides_base_ptr } + } +} + pub struct IrrtStringFields<'ctx> { pub buffer: GepGetter>>, pub capacity: GepGetter>, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7276507e..57d4deb9 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,5 +1,6 @@ use crate::typecheck::typedef::Type; +pub mod numpy; mod test; use super::{ diff --git a/nac3core/src/codegen/irrt/new.rs b/nac3core/src/codegen/irrt/new.rs index 50060687..c529bff7 100644 --- a/nac3core/src/codegen/irrt/new.rs +++ b/nac3core/src/codegen/irrt/new.rs @@ -19,7 +19,8 @@ fn get_size_variant(ty: IntType) -> SizeVariant { } } -fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String { +#[must_use] +pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String { let mut fn_name = fn_name.to_owned(); match get_size_variant(ty) { SizeVariant::Bits32 => { diff --git a/nac3core/src/codegen/irrt/numpy.rs b/nac3core/src/codegen/irrt/numpy.rs new file mode 100644 index 00000000..c81ba9da --- /dev/null +++ b/nac3core/src/codegen/irrt/numpy.rs @@ -0,0 +1,354 @@ +use std::marker::PhantomData; + +use inkwell::{ + types::BasicType, + values::{BasicValueEnum, IntValue}, +}; + +use crate::{ + codegen::{ + classes::{ListValue, UntypedArrayLikeAccessor}, + optics::{Address, AddressLens, ArraySlice, IntLens, Ixed, Optic}, + stmt::gen_for_callback_incrementing, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +use super::{ + classes::{ErrorContextLens, NpArray, NpArrayLens}, + new::{ + check_error_context, get_sized_dependent_function_name, prepare_error_context, + FunctionBuilder, + }, +}; + +type ProducerWriteToArray<'ctx, G, ElementOptic> = Box< + dyn Fn( + &mut G, + &mut CodeGenContext<'ctx, '_>, + &ArraySlice<'ctx, ElementOptic>, + ) -> Result<(), String> + + 'ctx, +>; + +struct Producer<'ctx, G: CodeGenerator + ?Sized, ElementOptic> { + pub count: IntValue<'ctx>, + pub write_to_array: ProducerWriteToArray<'ctx, G, ElementOptic>, +} + +/// TODO: UPDATE DOCUMENTATION +/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The `shape` parameter used to construct the `NDArray`. +/// +/// ### Notes on `shape` +/// +/// Just like numpy, the `shape` argument can be: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to +/// learn how `shape` gets from being a Python user expression to here. +fn parse_input_shape_arg<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, +) -> Producer<'ctx, G, IntLens<'ctx>> +where + G: CodeGenerator + ?Sized, +{ + let size_type = generator.get_size_type(ctx.ctx); + + match &*ctx.unifier.get_ty(shape_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of ints; e.g., `np.empty([600, 800, 3])` + + // A list has to be a PointerValue + let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None); + + // Create `Producer` + let ndims = shape_list.load_size(ctx, Some("count")); + Producer { + count: ndims, + write_to_array: Box::new(move |ctx, generator, dst_array| { + // Basically iterate through the list and write to `dst_slice` accordingly + let init_val = size_type.const_zero(); + let max_val = (ndims, false); + let incr_val = size_type.const_int(1, false); + gen_for_callback_incrementing( + ctx, + generator, + init_val, + max_val, + |generator, ctx, _hooks, axis| { + // Get the dimension at `axis` + let dim = + shape_list.data().get(ctx, generator, &axis, None).into_int_value(); + + // Cast `dim` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted") + .unwrap(); + + // Write + dst_array.ix(ctx, axis, "dim").store(ctx, &dim); + Ok(()) + }, + incr_val, + ) + }), + } + } + TypeEnum::TTuple { ty: tuple_types } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + // Get the length/size of the tuple, which also happens to be the value of `ndims`. + let ndims = tuple_types.len(); + + // A tuple has to be a StructValue + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + let shape_tuple = shape.into_struct_value(); + + Producer { + count: size_type.const_int(ndims as u64, false), + write_to_array: Box::new(move |_generator, ctx, dst_array| { + for axis in 0..ndims { + // Get the dimension at `axis` + let dim = ctx + .builder + .build_extract_value( + shape_tuple, + axis as u32, + format!("dim{axis}").as_str(), + ) + .unwrap() + .into_int_value(); + + // Cast `dim` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted") + .unwrap(); + + // Write + dst_array + .ix(ctx, size_type.const_int(axis as u64, false), "dim") + .store(ctx, &dim); + } + Ok(()) + }), + } + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + // The value has to be an integer + let shape_int = shape.into_int_value(); + + Producer { + count: size_type.const_int(1, false), + write_to_array: Box::new(move |_generator, ctx, dst_array| { + // Cast `shape_int` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted") + .unwrap(); + + // Write + dst_array + .ix(ctx, size_type.const_zero() /* Only index 0 is set */, "dim") + .store(ctx, &dim); + + Ok(()) + }), + } + } + _ => panic!("parse_input_shape_arg encountered unknown type"), + } +} + +pub fn alloca_ndarray<'ctx, G, ElementOptic: Optic<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + element_optic: ElementOptic, + ndims: IntValue<'ctx>, + name: &str, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + let size_type = generator.get_size_type(ctx.ctx); + + let itemsize = element_optic.get_llvm_type(ctx.ctx).size_of().unwrap(); + let itemsize = + ctx.builder.build_int_s_extend_or_bit_cast(itemsize, size_type, "itemsize").unwrap(); + + let shape = ctx.builder.build_array_alloca(size_type, ndims, "shape").unwrap(); + let strides = ctx.builder.build_array_alloca(size_type, ndims, "strides").unwrap(); + + let ndarray = NpArrayLens { size_type, element_optic }.alloca(ctx, name); + + // Set ndims, itemsize; and allocate shape and store on the stack + ndarray.focus(ctx, |fields| &fields.ndims).store(ctx, &ndims); + ndarray.focus(ctx, |fields| &fields.itemsize).store(ctx, &itemsize); + ndarray + .focus(ctx, |fields| &fields.shape) + .store(ctx, &Address { addressee_optic: IntLens(size_type), address: shape }); + ndarray + .focus(ctx, |fields| &fields.strides) + .store(ctx, &Address { addressee_optic: IntLens(size_type), address: strides }); + + Ok(ndarray) +} + +enum NDArrayInitMode<'ctx, G: CodeGenerator + ?Sized> { + NDim { ndim: IntValue<'ctx>, _phantom: PhantomData<&'ctx G> }, + Shape { shape: Producer<'ctx, G, IntLens<'ctx>> }, + ShapeAndAllocaData { shape: Producer<'ctx, G, IntLens<'ctx>> }, +} + +/// TODO: DOCUMENT ME +fn alloca_ndarray_and_init<'ctx, G, ElementOptic: Optic<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + element_optic: ElementOptic, + init_mode: NDArrayInitMode<'ctx, G>, + name: &str, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + // It is implemented verbosely in order to make the initialization modes super clear in their intent. + match init_mode { + NDArrayInitMode::NDim { ndim: ndims, _phantom } => { + let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + Ok(ndarray) + } + NDArrayInitMode::Shape { shape } => { + let ndims = shape.count; + let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + + // Fill `ndarray.shape` + (shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?; + + // Check if `shape` has bad inputs + call_nac3_ndarray_util_assert_shape_no_negative( + generator, + ctx, + ndims, + &ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), + ); + + // NOTE: DO NOT DO `set_strides_by_shape` HERE. + // Simply this is because we specified that `SetShape` wouldn't do `set_strides_by_shape` + + Ok(ndarray) + } + NDArrayInitMode::ShapeAndAllocaData { shape } => { + let ndims = shape.count; + let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + + // Fill `ndarray.shape` + (shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?; + + // Check if `shape` has bad inputs + call_nac3_ndarray_util_assert_shape_no_negative( + generator, + ctx, + ndims, + &ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), + ); + + // Now we populate `ndarray.data` by alloca-ing. + // But first, we need to know the size of the ndarray to know how many elements to alloca, + // since calculating nbytes of an ndarray requires `ndarray.shape` to be set. + let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray); + + // Alloca `data` and assign it to `ndarray.data` + let data_ptr = + ctx.builder.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes, "data").unwrap(); + ndarray.focus(ctx, |fields| &fields.data).store( + ctx, + &Address { addressee_optic: IntLens::int8(ctx.ctx), address: data_ptr }, + ); + + // Finally, do `set_strides_by_shape` + // Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are. + call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray); + + Ok(ndarray) + } + } +} + +fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndims: IntValue<'ctx>, + shape: &Address<'ctx, IntLens<'ctx>>, +) { + let size_type = generator.get_size_type(ctx.ctx); + + let errctx = prepare_error_context(ctx); + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name( + size_type, + "__nac3_ndarray_util_assert_shape_no_negative", + ), + ) + .arg("errctx", &AddressLens(ErrorContextLens), &errctx) + .arg("ndims", &IntLens(size_type), &ndims) + .arg("shape", &AddressLens(IntLens(size_type)), shape) + .returning_void(); + check_error_context(generator, ctx, &errctx); +} + +fn call_nac3_ndarray_set_strides_by_shape< + 'ctx, + G: CodeGenerator + ?Sized, + ElementOptic: Optic<'ctx>, +>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: &NpArray<'ctx, ElementOptic>, +) { + let size_type = generator.get_size_type(ctx.ctx); + + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name( + size_type, + "__nac3_ndarray_util_assert_shape_no_negative", + ), + ) + .arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray) + .returning_void(); +} + +fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized, ElementOptic: Optic<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: &NpArray<'ctx, ElementOptic>, +) -> IntValue<'ctx> { + let size_type = generator.get_size_type(ctx.ctx); + + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name( + size_type, + "__nac3_ndarray_util_assert_shape_no_negative", + ), + ) + .arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray) + .returning("nbytes", &IntLens(size_type)) +} diff --git a/nac3core/src/codegen/optics.rs b/nac3core/src/codegen/optics.rs index dcfd727c..ddc021a1 100644 --- a/nac3core/src/codegen/optics.rs +++ b/nac3core/src/codegen/optics.rs @@ -58,6 +58,23 @@ pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {} #[derive(Debug, Clone, Copy)] pub struct IntLens<'ctx>(pub IntType<'ctx>); +impl<'ctx> IntLens<'ctx> { + #[must_use] + pub fn int8(ctx: &'ctx Context) -> IntLens<'ctx> { + IntLens(ctx.i8_type()) + } + + #[must_use] + pub fn int32(ctx: &'ctx Context) -> IntLens<'ctx> { + IntLens(ctx.i32_type()) + } + + #[must_use] + pub fn int64(ctx: &'ctx Context) -> IntLens<'ctx> { + IntLens(ctx.i64_type()) + } +} + impl<'ctx> Optic<'ctx> for IntLens<'ctx> { type Value = IntValue<'ctx>; @@ -111,7 +128,7 @@ impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> { } pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> { - self.cast_to(ctx, IntLens(ctx.ctx.i8_type())) + self.cast_to(ctx, IntLens::int8(ctx.ctx)) } } @@ -126,7 +143,7 @@ pub struct AddressLens(pub AddresseeOptic); impl AddressLens { pub fn new_opaque<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> AddressLens> { - AddressLens(IntLens(ctx.ctx.i8_type())) + AddressLens(IntLens::int8(ctx.ctx)) } }