From f090b3b4d2ad5a08d2e2abb206238a856fe04805 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 11:38:05 +0800 Subject: [PATCH] core/ndstrides: define ndarray with strides --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/def.hpp | 47 +++++++++++++++++ nac3core/src/codegen/mod.rs | 13 ++--- nac3core/src/codegen/object/mod.rs | 1 + nac3core/src/codegen/object/ndarray/mod.rs | 59 ++++++++++++++++++++++ 5 files changed, 114 insertions(+), 9 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/def.hpp create mode 100644 nac3core/src/codegen/object/ndarray/mod.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 31ba09f1..e5227917 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,4 +1,5 @@ #include #include #include -#include +#include +#include \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/def.hpp b/nac3core/irrt/irrt/ndarray/def.hpp new file mode 100644 index 00000000..270109b1 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/def.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include + +namespace +{ +/** + * @brief The NDArray object + * + * The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst + */ +template struct NDArray +{ + /** + * @brief The underlying data this `ndarray` is pointing to. + * + * Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized. + */ + uint8_t *data; + + /** + * @brief The number of bytes of a single element in `data`. + */ + SizeT itemsize; + + /** + * @brief The number of dimensions of this shape. + */ + SizeT ndims; + + /** + * @brief The NDArray shape, with length equal to `ndims`. + * + * Note that it may contain 0. + */ + SizeT *shape; + + /** + * @brief Array strides, with length equal to `ndims` + * + * The stride values are in units of bytes, not number of elements. + * + * Note that `strides` can have negative values. + */ + SizeT *strides; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 697fcf89..523a7ec7 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ - codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, + codegen::classes::{ListType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -24,7 +24,9 @@ use inkwell::{ AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; +use model::*; use nac3parser::ast::{Location, Stmt, StrRef}; +use object::ndarray::NDArray; use parking_lot::{Condvar, Mutex}; use std::collections::{HashMap, HashSet}; use std::sync::{ @@ -491,12 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, dtype, - ); - - NDArrayType::new(generator, ctx, element_type).as_base_type().into() + Ptr(Struct(NDArray)).get_type(generator, ctx).as_basic_type_enum() } _ => unreachable!( diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs index 93d2f453..f9abd433 100644 --- a/nac3core/src/codegen/object/mod.rs +++ b/nac3core/src/codegen/object/mod.rs @@ -1 +1,2 @@ pub mod any; +pub mod ndarray; diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs new file mode 100644 index 00000000..53f71df4 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -0,0 +1,59 @@ +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, + typecheck::typedef::Type, +}; + +use super::any::AnyObject; + +/// Fields of [`NDArray`] +pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { + pub data: F::Out>>, + pub itemsize: F::Out>, + pub ndims: F::Out>, + pub shape: F::Out>>, + pub strides: F::Out>>, +} + +/// A strided ndarray in NAC3. +/// +/// See IRRT implementation for details about its fields. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDArray; + +impl<'ctx> StructKind<'ctx> for NDArray { + type Fields> = NDArrayFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + data: traversal.add_auto("data"), + itemsize: traversal.add_auto("itemsize"), + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + } + } +} + +/// A NAC3 Python ndarray object. +#[derive(Debug, Clone, Copy)] +pub struct NDArrayObject<'ctx> { + pub dtype: Type, + pub ndims: u64, + pub instance: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap(); + NDArrayObject { dtype, ndims, instance: value } + } +}