forked from M-Labs/nac3
core: irrt ndarray setup
This commit is contained in:
parent
b4d5b2a41f
commit
3b87bd36f3
|
@ -0,0 +1,134 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/ndarray/ndarray_util.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||||||
|
//
|
||||||
|
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||||||
|
//
|
||||||
|
// Some resources you might find helpful:
|
||||||
|
// - The official numpy implementations:
|
||||||
|
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||||
|
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||||||
|
// - https://ajcr.net/stride-guide-part-1/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-2/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-3/.
|
||||||
|
template <typename SizeT>
|
||||||
|
struct NDArray {
|
||||||
|
// The underlying data this `ndarray` is pointing to.
|
||||||
|
//
|
||||||
|
// NOTE: Formally this should be of type `void *`, but clang
|
||||||
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||||
|
// so we will put `uint8_t *` here for clarity.
|
||||||
|
//
|
||||||
|
// This pointer should point to the first element of the ndarray directly
|
||||||
|
uint8_t *data;
|
||||||
|
|
||||||
|
// The number of bytes of a single element in `data`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT itemsize;
|
||||||
|
|
||||||
|
// The number of dimensions of this shape.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT ndims;
|
||||||
|
|
||||||
|
// Array shape, with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
//
|
||||||
|
// NOTE: `shape` can contain 0.
|
||||||
|
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||||||
|
SizeT *shape;
|
||||||
|
|
||||||
|
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `signed`.
|
||||||
|
//
|
||||||
|
// NOTE: `strides` can have negative numbers.
|
||||||
|
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||||||
|
SizeT *strides;
|
||||||
|
|
||||||
|
// Calculate the size/# of elements of an `ndarray`.
|
||||||
|
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||||||
|
SizeT size() {
|
||||||
|
return ndarray_util::calc_size_from_shape(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||||||
|
// This function corresponds to `ndarray.nbytes`
|
||||||
|
SizeT nbytes() {
|
||||||
|
return this->size() * itemsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||||||
|
void set_strides_by_shape() {
|
||||||
|
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t* get_pelement_by_indices(const SizeT *indices) {
|
||||||
|
uint8_t* element = data;
|
||||||
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
||||||
|
element += indices[dim_i] * strides[dim_i];
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t* get_nth_pelement(SizeT nth) {
|
||||||
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
||||||
|
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
||||||
|
return get_pelement_by_indices(indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the pointer to the nth element of the ndarray as if it were flattened.
|
||||||
|
uint8_t* checked_get_nth_pelement(ErrorContext* errctx, SizeT nth) {
|
||||||
|
SizeT arr_size = this->size();
|
||||||
|
if (!(0 <= nth && nth < arr_size)) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->index_error,
|
||||||
|
"index {0} is out of bounds, valid range is {1} <= index < {2}",
|
||||||
|
nth, 0, arr_size
|
||||||
|
);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return get_nth_pelement(nth);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray->nbytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray->nbytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||||
|
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
||||||
|
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||||
|
ndarray->set_strides_by_shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||||
|
ndarray->set_strides_by_shape();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,107 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray_util {
|
||||||
|
|
||||||
|
// Throw an error if there is an axis with negative dimension
|
||||||
|
template <typename SizeT>
|
||||||
|
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) {
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
if (shape[axis] < 0) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"negative dimensions are not allowed; axis {0} has dimension {1}",
|
||||||
|
axis, shape[axis]
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the size/# of elements of an ndarray given its shape
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||||
|
SizeT size = 1;
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the strides of an ndarray given an ndarray `shape`
|
||||||
|
// and assuming that the ndarray is *fully C-contagious*.
|
||||||
|
//
|
||||||
|
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
int axis = ndims - i - 1;
|
||||||
|
dst_strides[axis] = stride_product * itemsize;
|
||||||
|
stride_product *= shape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (int32_t i = 0; i < ndims; i++) {
|
||||||
|
int32_t axis = ndims - i - 1;
|
||||||
|
int32_t dim = shape[axis];
|
||||||
|
|
||||||
|
indices[axis] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
bool can_broadcast_shape_to(
|
||||||
|
const SizeT target_ndims,
|
||||||
|
const SizeT *target_shape,
|
||||||
|
const SizeT src_ndims,
|
||||||
|
const SizeT *src_shape
|
||||||
|
) {
|
||||||
|
/*
|
||||||
|
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
|
||||||
|
This function handles this example:
|
||||||
|
```
|
||||||
|
Image (3d array): 256 x 256 x 3
|
||||||
|
Scale (1d array): 3
|
||||||
|
Result (3d array): 256 x 256 x 3
|
||||||
|
```
|
||||||
|
|
||||||
|
Other interesting examples to consider:
|
||||||
|
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||||
|
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||||
|
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||||
|
|
||||||
|
In cases when the shapes contain zero(es):
|
||||||
|
- `can_broadcast_shape_to([0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0], [2]) == false`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This is essentially doing the following in Python:
|
||||||
|
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||||
|
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||||
|
SizeT target_axis = target_ndims - i - 1;
|
||||||
|
SizeT src_axis = src_ndims - i - 1;
|
||||||
|
|
||||||
|
bool target_dim_exists = target_axis >= 0;
|
||||||
|
bool src_dim_exists = src_axis >= 0;
|
||||||
|
|
||||||
|
SizeT target_dim = target_dim_exists ? target_shape[target_axis] : 1;
|
||||||
|
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
||||||
|
|
||||||
|
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||||
|
if (!ok) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,3 +4,4 @@
|
||||||
#include <irrt/error_context.hpp>
|
#include <irrt/error_context.hpp>
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
||||||
|
#include <irrt/ndarray/ndarray.hpp>
|
|
@ -1,6 +1,11 @@
|
||||||
use inkwell::types::{BasicTypeEnum, IntType};
|
use inkwell::types::IntType;
|
||||||
|
|
||||||
use crate::codegen::optics::{AddressLens, FieldBuilder, GepGetter, IntLens, StructureOptic};
|
use crate::codegen::{
|
||||||
|
optics::{
|
||||||
|
Address, AddressLens, ArraySlice, FieldBuilder, GepGetter, IntLens, Optic, StructureOptic,
|
||||||
|
},
|
||||||
|
CodeGenContext,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct StrLens<'ctx> {
|
pub struct StrLens<'ctx> {
|
||||||
|
@ -36,13 +41,18 @@ pub struct NpArrayFields<'ctx> {
|
||||||
pub strides: GepGetter<AddressLens<IntLens<'ctx>>>,
|
pub strides: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: NpArrayLens's ElementOptic is purely for type-safety and type-guidances
|
||||||
|
// The underlying LLVM ndarray doesn't care, it only holds an opaque (uint8_t*) pointer to the elements.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct NpArrayLens<'ctx> {
|
pub struct NpArrayLens<'ctx, ElementOptic> {
|
||||||
pub size_type: IntType<'ctx>,
|
pub size_type: IntType<'ctx>,
|
||||||
pub elem_type: BasicTypeEnum<'ctx>,
|
pub element_optic: ElementOptic,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> {
|
// NDArray is *frequently* used, so here is a type alias
|
||||||
|
pub type NpArray<'ctx, ElementOptic> = Address<'ctx, NpArrayLens<'ctx, ElementOptic>>;
|
||||||
|
|
||||||
|
impl<'ctx, ElementOptic: Optic<'ctx>> StructureOptic<'ctx> for NpArrayLens<'ctx, ElementOptic> {
|
||||||
type Fields = NpArrayFields<'ctx>;
|
type Fields = NpArrayFields<'ctx>;
|
||||||
|
|
||||||
fn struct_name(&self) -> &'static str {
|
fn struct_name(&self) -> &'static str {
|
||||||
|
@ -63,6 +73,21 @@ impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Other convenient utilities for NpArray
|
||||||
|
impl<'ctx, ElementOptic: Optic<'ctx>> NpArray<'ctx, ElementOptic> {
|
||||||
|
pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> {
|
||||||
|
let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims");
|
||||||
|
let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape");
|
||||||
|
ArraySlice { num_elements: ndims, base: shape_base_ptr }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn strides_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> {
|
||||||
|
let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims");
|
||||||
|
let strides_base_ptr = self.focus(ctx, |fields| &fields.strides).load(ctx, "strides");
|
||||||
|
ArraySlice { num_elements: ndims, base: strides_base_ptr }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct IrrtStringFields<'ctx> {
|
pub struct IrrtStringFields<'ctx> {
|
||||||
pub buffer: GepGetter<AddressLens<IntLens<'ctx>>>,
|
pub buffer: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
pub capacity: GepGetter<IntLens<'ctx>>,
|
pub capacity: GepGetter<IntLens<'ctx>>,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
|
pub mod numpy;
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
|
|
|
@ -19,7 +19,8 @@ fn get_size_variant(ty: IntType) -> SizeVariant {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String {
|
#[must_use]
|
||||||
|
pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String {
|
||||||
let mut fn_name = fn_name.to_owned();
|
let mut fn_name = fn_name.to_owned();
|
||||||
match get_size_variant(ty) {
|
match get_size_variant(ty) {
|
||||||
SizeVariant::Bits32 => {
|
SizeVariant::Bits32 => {
|
||||||
|
|
|
@ -0,0 +1,354 @@
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use inkwell::{
|
||||||
|
types::BasicType,
|
||||||
|
values::{BasicValueEnum, IntValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
classes::{ListValue, UntypedArrayLikeAccessor},
|
||||||
|
optics::{Address, AddressLens, ArraySlice, IntLens, Ixed, Optic},
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
classes::{ErrorContextLens, NpArray, NpArrayLens},
|
||||||
|
new::{
|
||||||
|
check_error_context, get_sized_dependent_function_name, prepare_error_context,
|
||||||
|
FunctionBuilder,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
type ProducerWriteToArray<'ctx, G, ElementOptic> = Box<
|
||||||
|
dyn Fn(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, '_>,
|
||||||
|
&ArraySlice<'ctx, ElementOptic>,
|
||||||
|
) -> Result<(), String>
|
||||||
|
+ 'ctx,
|
||||||
|
>;
|
||||||
|
|
||||||
|
struct Producer<'ctx, G: CodeGenerator + ?Sized, ElementOptic> {
|
||||||
|
pub count: IntValue<'ctx>,
|
||||||
|
pub write_to_array: ProducerWriteToArray<'ctx, G, ElementOptic>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO: UPDATE DOCUMENTATION
|
||||||
|
/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
|
///
|
||||||
|
/// ### Notes on `shape`
|
||||||
|
///
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
||||||
|
/// learn how `shape` gets from being a Python user expression to here.
|
||||||
|
fn parse_input_shape_arg<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
) -> Producer<'ctx, G, IntLens<'ctx>>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(shape_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
// A list has to be a PointerValue
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None);
|
||||||
|
|
||||||
|
// Create `Producer`
|
||||||
|
let ndims = shape_list.load_size(ctx, Some("count"));
|
||||||
|
Producer {
|
||||||
|
count: ndims,
|
||||||
|
write_to_array: Box::new(move |ctx, generator, dst_array| {
|
||||||
|
// Basically iterate through the list and write to `dst_slice` accordingly
|
||||||
|
let init_val = size_type.const_zero();
|
||||||
|
let max_val = (ndims, false);
|
||||||
|
let incr_val = size_type.const_int(1, false);
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
init_val,
|
||||||
|
max_val,
|
||||||
|
|generator, ctx, _hooks, axis| {
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim =
|
||||||
|
shape_list.data().get(ctx, generator, &axis, None).into_int_value();
|
||||||
|
|
||||||
|
// Cast `dim` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array.ix(ctx, axis, "dim").store(ctx, &dim);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
incr_val,
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty: tuple_types } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
|
let ndims = tuple_types.len();
|
||||||
|
|
||||||
|
// A tuple has to be a StructValue
|
||||||
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||||
|
let shape_tuple = shape.into_struct_value();
|
||||||
|
|
||||||
|
Producer {
|
||||||
|
count: size_type.const_int(ndims as u64, false),
|
||||||
|
write_to_array: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
for axis in 0..ndims {
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(
|
||||||
|
shape_tuple,
|
||||||
|
axis as u32,
|
||||||
|
format!("dim{axis}").as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
|
||||||
|
// Cast `dim` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array
|
||||||
|
.ix(ctx, size_type.const_int(axis as u64, false), "dim")
|
||||||
|
.store(ctx, &dim);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
// The value has to be an integer
|
||||||
|
let shape_int = shape.into_int_value();
|
||||||
|
|
||||||
|
Producer {
|
||||||
|
count: size_type.const_int(1, false),
|
||||||
|
write_to_array: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
// Cast `shape_int` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array
|
||||||
|
.ix(ctx, size_type.const_zero() /* Only index 0 is set */, "dim")
|
||||||
|
.store(ctx, &dim);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("parse_input_shape_arg encountered unknown type"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloca_ndarray<'ctx, G, ElementOptic: Optic<'ctx>>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
element_optic: ElementOptic,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<NpArray<'ctx, ElementOptic>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let itemsize = element_optic.get_llvm_type(ctx.ctx).size_of().unwrap();
|
||||||
|
let itemsize =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(itemsize, size_type, "itemsize").unwrap();
|
||||||
|
|
||||||
|
let shape = ctx.builder.build_array_alloca(size_type, ndims, "shape").unwrap();
|
||||||
|
let strides = ctx.builder.build_array_alloca(size_type, ndims, "strides").unwrap();
|
||||||
|
|
||||||
|
let ndarray = NpArrayLens { size_type, element_optic }.alloca(ctx, name);
|
||||||
|
|
||||||
|
// Set ndims, itemsize; and allocate shape and store on the stack
|
||||||
|
ndarray.focus(ctx, |fields| &fields.ndims).store(ctx, &ndims);
|
||||||
|
ndarray.focus(ctx, |fields| &fields.itemsize).store(ctx, &itemsize);
|
||||||
|
ndarray
|
||||||
|
.focus(ctx, |fields| &fields.shape)
|
||||||
|
.store(ctx, &Address { addressee_optic: IntLens(size_type), address: shape });
|
||||||
|
ndarray
|
||||||
|
.focus(ctx, |fields| &fields.strides)
|
||||||
|
.store(ctx, &Address { addressee_optic: IntLens(size_type), address: strides });
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
enum NDArrayInitMode<'ctx, G: CodeGenerator + ?Sized> {
|
||||||
|
NDim { ndim: IntValue<'ctx>, _phantom: PhantomData<&'ctx G> },
|
||||||
|
Shape { shape: Producer<'ctx, G, IntLens<'ctx>> },
|
||||||
|
ShapeAndAllocaData { shape: Producer<'ctx, G, IntLens<'ctx>> },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO: DOCUMENT ME
|
||||||
|
fn alloca_ndarray_and_init<'ctx, G, ElementOptic: Optic<'ctx>>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
element_optic: ElementOptic,
|
||||||
|
init_mode: NDArrayInitMode<'ctx, G>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<NpArray<'ctx, ElementOptic>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
// It is implemented verbosely in order to make the initialization modes super clear in their intent.
|
||||||
|
match init_mode {
|
||||||
|
NDArrayInitMode::NDim { ndim: ndims, _phantom } => {
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?;
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
NDArrayInitMode::Shape { shape } => {
|
||||||
|
let ndims = shape.count;
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?;
|
||||||
|
|
||||||
|
// Fill `ndarray.shape`
|
||||||
|
(shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?;
|
||||||
|
|
||||||
|
// Check if `shape` has bad inputs
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndims,
|
||||||
|
&ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"),
|
||||||
|
);
|
||||||
|
|
||||||
|
// NOTE: DO NOT DO `set_strides_by_shape` HERE.
|
||||||
|
// Simply this is because we specified that `SetShape` wouldn't do `set_strides_by_shape`
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
NDArrayInitMode::ShapeAndAllocaData { shape } => {
|
||||||
|
let ndims = shape.count;
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?;
|
||||||
|
|
||||||
|
// Fill `ndarray.shape`
|
||||||
|
(shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?;
|
||||||
|
|
||||||
|
// Check if `shape` has bad inputs
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndims,
|
||||||
|
&ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Now we populate `ndarray.data` by alloca-ing.
|
||||||
|
// But first, we need to know the size of the ndarray to know how many elements to alloca,
|
||||||
|
// since calculating nbytes of an ndarray requires `ndarray.shape` to be set.
|
||||||
|
let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray);
|
||||||
|
|
||||||
|
// Alloca `data` and assign it to `ndarray.data`
|
||||||
|
let data_ptr =
|
||||||
|
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes, "data").unwrap();
|
||||||
|
ndarray.focus(ctx, |fields| &fields.data).store(
|
||||||
|
ctx,
|
||||||
|
&Address { addressee_optic: IntLens::int8(ctx.ctx), address: data_ptr },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Finally, do `set_strides_by_shape`
|
||||||
|
// Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are.
|
||||||
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
shape: &Address<'ctx, IntLens<'ctx>>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let errctx = prepare_error_context(ctx);
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), &errctx)
|
||||||
|
.arg("ndims", &IntLens(size_type), &ndims)
|
||||||
|
.arg("shape", &AddressLens(IntLens(size_type)), shape)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, &errctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_set_strides_by_shape<
|
||||||
|
'ctx,
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
ElementOptic: Optic<'ctx>,
|
||||||
|
>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: &NpArray<'ctx, ElementOptic>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized, ElementOptic: Optic<'ctx>>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: &NpArray<'ctx, ElementOptic>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray)
|
||||||
|
.returning("nbytes", &IntLens(size_type))
|
||||||
|
}
|
|
@ -58,6 +58,23 @@ pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {}
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct IntLens<'ctx>(pub IntType<'ctx>);
|
pub struct IntLens<'ctx>(pub IntType<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> IntLens<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn int8(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i8_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn int32(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i32_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn int64(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i64_type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> Optic<'ctx> for IntLens<'ctx> {
|
impl<'ctx> Optic<'ctx> for IntLens<'ctx> {
|
||||||
type Value = IntValue<'ctx>;
|
type Value = IntValue<'ctx>;
|
||||||
|
|
||||||
|
@ -111,7 +128,7 @@ impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> {
|
pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> {
|
||||||
self.cast_to(ctx, IntLens(ctx.ctx.i8_type()))
|
self.cast_to(ctx, IntLens::int8(ctx.ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +143,7 @@ pub struct AddressLens<AddresseeOptic>(pub AddresseeOptic);
|
||||||
|
|
||||||
impl<AddresseeOptic> AddressLens<AddresseeOptic> {
|
impl<AddresseeOptic> AddressLens<AddresseeOptic> {
|
||||||
pub fn new_opaque<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> AddressLens<IntLens<'ctx>> {
|
pub fn new_opaque<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> AddressLens<IntLens<'ctx>> {
|
||||||
AddressLens(IntLens(ctx.ctx.i8_type()))
|
AddressLens(IntLens::int8(ctx.ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue