Compare commits

..

17 Commits

Author SHA1 Message Date
David Mak c862dbd861 [core] WIP - Implemented construct_* for NDArrays 2024-11-15 15:29:22 +08:00
David Mak 684aafe54c [core] Add itemsize and strides to NDArray struct 2024-11-15 15:28:37 +08:00
David Mak a770d9f415 [core] coregen/types: Implement StructFields for NDArray
Also rename some fields to better align with their naming in numpy.
2024-11-15 15:27:45 +08:00
David Mak 6f702ac250 [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-15 15:19:23 +08:00
David Mak 64ec66d3dd [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-15 15:19:23 +08:00
David Mak 3c336b0ea5 [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Remove __STDC_VERSION__ guard
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-15 15:19:23 +08:00
David Mak 9fab65109a [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-15 15:19:23 +08:00
David Mak 3a5e7a98b1 [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-15 15:19:23 +08:00
lyken d3fb4204e7 core/irrt: fix exception.hpp C++ castings 2024-11-15 15:19:23 +08:00
lyken 8631fc8b58 core/toplevel/helper: add {extract,create}_ndims 2024-11-15 15:19:23 +08:00
David Mak 4b666f8706 [core] codegen/types: Implement StructField{,s}
Loosely based on FieldTraversal.
2024-11-15 15:19:06 +08:00
David Mak 71300f6c86 [core] codegen: Refactor ProxyType and ProxyValue
Accepts generator+context object for generic type checking. Also
implements more default trait impl for easier delegation.
2024-11-11 20:56:18 +08:00
David Mak cd0793f80e [core] Move Proxies to their own modules 2024-11-11 20:45:28 +08:00
David Mak b1adfb245c [core] codegen/classes: Remove Underlying type
This is confusing and we want a better abstraction than this.
2024-11-11 19:09:50 +08:00
David Mak 383989b142 core: WIP - Add tracer runtime 2024-11-11 19:09:48 +08:00
David Mak c92937eca3 [meta] Update pre-commit configuration 2024-11-11 15:00:33 +08:00
David Mak 9e8facf355 [meta] Update cargo dependencies 2024-11-11 15:00:33 +08:00
22 changed files with 1079 additions and 681 deletions

176
Cargo.lock generated
View File

@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.1" version = "1.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" checksum = "40545c26d092346d8a8dab71ee48e7685a7a9cba76e634790c215b41a4a7b4cf"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@ -141,9 +141,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.21" version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -151,9 +151,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.21" version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -175,9 +175,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.3" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
@ -199,9 +199,9 @@ dependencies = [
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.15" version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -282,12 +282,6 @@ dependencies = [
"crypto-common", "crypto-common",
] ]
[[package]]
name = "dissimilar"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59f8e79d1fbf76bdfbde321e902714bf6c49df88a7dda6fc682fc2979226962d"
[[package]] [[package]]
name = "either" name = "either"
version = "1.13.0" version = "1.13.0"
@ -376,12 +370,6 @@ dependencies = [
"wasi", "wasi",
] ]
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.12.3" version = "0.12.3"
@ -559,9 +547,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.164" version = "0.2.162"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398"
[[package]] [[package]]
name = "libloading" name = "libloading"
@ -660,7 +648,6 @@ dependencies = [
"inkwell", "inkwell",
"insta", "insta",
"itertools", "itertools",
"nac3core_derive",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
"rayon", "rayon",
@ -670,18 +657,6 @@ dependencies = [
"test-case", "test-case",
] ]
[[package]]
name = "nac3core_derive"
version = "0.1.0"
dependencies = [
"nac3core",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.87",
"trybuild",
]
[[package]] [[package]]
name = "nac3ld" name = "nac3ld"
version = "0.1.0" version = "0.1.0"
@ -847,30 +822,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.89" version = "1.0.89"
@ -1025,9 +976,9 @@ dependencies = [
[[package]] [[package]]
name = "regex-automata" name = "regex-automata"
version = "0.4.9" version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -1049,9 +1000,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.41" version = "0.38.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"errno", "errno",
@ -1095,18 +1046,18 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.215" version = "1.0.214"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.215" version = "1.0.214"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1115,9 +1066,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.133" version = "1.0.132"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr", "memchr",
@ -1125,15 +1076,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_spanned"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_yaml" name = "serde_yaml"
version = "0.8.26" version = "0.8.26"
@ -1257,12 +1199,6 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "target-triple"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.14.0" version = "3.14.0"
@ -1286,15 +1222,6 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "test-case" name = "test-case"
version = "1.2.3" version = "1.2.3"
@ -1328,56 +1255,6 @@ dependencies = [
"syn 2.0.87", "syn 2.0.87",
] ]
[[package]]
name = "toml"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_edit",
]
[[package]]
name = "toml_datetime"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
dependencies = [
"indexmap 2.6.0",
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
]
[[package]]
name = "trybuild"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4"
dependencies = [
"dissimilar",
"glob",
"serde",
"serde_derive",
"serde_json",
"target-triple",
"termcolor",
"toml",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.17.0" version = "1.17.0"
@ -1601,15 +1478,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "yaml-rust" name = "yaml-rust"
version = "0.4.5" version = "0.4.5"

View File

@ -4,7 +4,6 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3core/nac3core_derive",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -1370,6 +1370,7 @@ fn polymorphic_print<'ctx>(
let val = NDArrayValue::from_pointer_value( let val = NDArrayValue::from_pointer_value(
value.into_pointer_value(), value.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );

View File

@ -5,8 +5,6 @@ authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features] [features]
default = ["derive"]
derive = ["dep:nac3core_derive"]
no-escape-analysis = [] no-escape-analysis = []
tracing = [] tracing = []
@ -16,7 +14,6 @@ crossbeam = "0.8"
indexmap = "2.6" indexmap = "2.6"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.10" rayon = "1.10"
nac3core_derive = { path = "nac3core_derive", optional = true }
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26" strum = "0.26"
strum_macros = "0.26" strum_macros = "0.26"

View File

@ -3,3 +3,5 @@
#include "irrt/math.hpp" #include "irrt/math.hpp"
#include "irrt/ndarray.hpp" #include "irrt/ndarray.hpp"
#include "irrt/slice.hpp" #include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"

View File

@ -2,6 +2,8 @@
#include "irrt/int_types.hpp" #include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace { namespace {
template<typename SizeT> template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {

View File

@ -0,0 +1,342 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace basic {
/**
* @brief Assert that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template<typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
*/
template<typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims,
const SizeT* ndarray_shape,
SizeT output_ndims,
const SizeT* output_shape) {
if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM);
}
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
if (ndarray_shape[axis] != output_shape[axis]) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has "
"dimension {1}, but destination ndarray has dimension {2}.",
axis, output_shape[axis], ndarray_shape[axis]);
}
}
}
/**
* @brief Return the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template<typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++)
size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template<typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
SizeT dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template<typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template<typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The length.
*/
template<typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
if (ndarray->ndims != 0) {
return ndarray->shape[0];
}
// numpy prohibits `__len__` on unsized objects
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
__builtin_unreachable();
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see ndarray's rules for C-contiguity:
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template<typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// References:
// - tinynumpy's implementation:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]:
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity:
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
*
* This function does no bound check.
*/
template<typename SizeT>
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
void* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
*
* This function does no bound check.
*/
template<typename SizeT>
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
void* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis];
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
nth /= dim;
}
return element;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template<typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template<typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template<typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy when we see a contiguous segment.
// TODO: Handle overlapping.
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
const int64_t* ndarray_shape,
int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
return len(ndarray);
}
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
return len(ndarray);
}
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,51 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
/**
* @brief The NDArray object
*
* Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
*
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
* `data`. There are also minor differences in the struct layout.
*/
template<typename SizeT>
struct NDArray {
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values or contain 0.
*/
SizeT* strides;
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
};
} // namespace

View File

@ -1,21 +0,0 @@
[package]
name = "nac3core_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[[test]]
name = "structfields_tests"
path = "tests/structfields_test.rs"
[dev-dependencies]
nac3core = { path = ".." }
trybuild = { version = "1.0", features = ["diff"] }
[dependencies]
proc-macro2 = "1.0"
proc-macro-error = "1.0"
syn = "2.0"
quote = "1.0"

View File

@ -1,231 +0,0 @@
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::spanned::Spanned;
use syn::{
parse_macro_input, Data, DataStruct, Expr, ExprPath, GenericArgument, Ident, LitStr,
PathArguments, Type, TypePath,
};
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
///
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
/// `expected_ty_name`, otherwise returns [`None`].
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
return None;
};
let segments = &path.segments;
if segments.len() != 1 {
return None;
};
let segment = segments.iter().next().unwrap();
if segment.ident != expected_ty_name {
return None;
}
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
return Some(Vec::new());
};
let args = &path_args.args;
Some(args.iter().cloned().collect::<Vec<_>>())
}
/// Normalizes a value expression for use when creating an instance of this structure, returning a
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
match &expr {
Expr::Path(ExprPath { qself: None, path, .. }) => {
let Ok(ident) = path.require_ident() else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
};
if ident == "usize" || ident == "size_t" {
let llvm_usize = Ident::new("llvm_usize", ident.span());
quote! { #llvm_usize }
} else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
Expr::Call(_) | Expr::MethodCall(_) => {
quote! { ctx.#expr }
}
_ => {
abort!(
expr,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
}
/// Derives an implementation of `codegen::types::structure::StructFields`.
///
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
///
/// # Prerequisites
///
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
/// `StructFields`.
///
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
/// with either `StructField` or [`PhantomData`] types.
///
/// # Attributes for [`StructFields`]
///
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
/// accepts either an expression returning an instance of `inkwell::types::BasicType` without the
/// `inkwell::context::Context` instance prefix, or the reserved identifiers `usize` and `size_t` referring to an
/// `inkwell::types::IntType` of the platform-dependent integer size.
///
/// # Example
///
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
///
/// ```
/// use nac3core::{
/// codegen::types::structure::StructField,
/// inkwell::{
/// values::{IntValue, PointerValue},
/// AddressSpace,
/// },
/// };
/// use nac3core_derive::StructFields;
///
/// // All classes that implement StructFields must also implement Eq and Copy
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
/// pub struct SliceValue<'ctx> {
/// // Declares ptr have a value type of i8*
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
///
/// // Declares len have a value type of usize, depending on the target compilation platform
/// #[value_type(usize)]
/// len: StructField<'ctx, IntValue<'ctx>>,
/// }
/// ```
#[proc_macro_derive(StructFields, attributes(value_type))]
#[proc_macro_error]
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::DeriveInput);
let ident = &input.ident;
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
abort!(input, "Only structs with named fields are supported");
};
if let Err(err_span) =
fields
.iter()
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
{
abort!(err_span, "Only structs with named fields are supported");
};
// Check if struct<'ctx>
if input.generics.params.len() != 1 {
abort!(input.generics, "Expected exactly 1 generic parameter")
}
let phantom_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
.map(|field| field.ident.as_ref().unwrap())
.cloned()
.collect::<Vec<_>>();
let field_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
let ty = &field.ty;
let Some(_) = extract_generic_args("StructField", ty) else {
abort!(field, "Only StructField and PhantomData are allowed")
};
let attrs = &field.attrs;
let Some(value_type_attr) =
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
else {
abort!(field, "Expected #[value_type(...)] attribute for field");
};
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
};
let value_expr_toks = normalize_value_expr(&value_type_expr);
(ident.clone(), value_expr_toks)
})
.collect::<Vec<_>>();
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
let phantoms_create = phantom_info
.iter()
.map(|id| quote! { #id: ::std::marker::PhantomData })
.collect::<Vec<_>>();
let fields_create = field_info
.iter()
.map(|(id, ty)| {
let id_lit = LitStr::new(&id.to_string(), id.span());
quote! {
#id: ::nac3core::codegen::types::structure::StructField::create(
&mut counter,
#id_lit,
#ty,
)
}
})
.collect::<Vec<_>>();
// `.into()` impl of `StructField` for `StructFields::to_vec`
let fields_into =
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
let impl_block = quote! {
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
#ident {
#(#fields_create),*
#(#phantoms_create),*
}
}
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
vec![
#(#fields_into),*
]
}
}
};
impl_block.into()
}

View File

@ -1,9 +0,0 @@
use nac3core_derive::StructFields;
use std::marker::PhantomData;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct EmptyValue<'ctx> {
_phantom: PhantomData<&'ctx ()>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,6 +0,0 @@
#[test]
fn test_parse_empty() {
let t = trybuild::TestCases::new();
t.pass("tests/structfields_empty.rs");
t.pass("tests/structfields_slice.rs");
}

View File

@ -74,6 +74,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
let arg = NDArrayValue::from_pointer_value( let arg = NDArrayValue::from_pointer_value(
arg.into_pointer_value(), arg.into_pointer_value(),
ctx.get_llvm_type(generator, elem_ty), ctx.get_llvm_type(generator, elem_ty),
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -153,7 +154,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int32, ctx.primitives.int32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -216,7 +217,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.int64, ctx.primitives.int64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -295,7 +296,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint32, ctx.primitives.uint32,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
)?; )?;
@ -363,7 +364,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.uint64, ctx.primitives.uint64,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
)?; )?;
@ -430,7 +431,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
)?; )?;
@ -477,7 +478,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -518,7 +519,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
)?; )?;
@ -584,7 +585,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ctx.primitives.bool, ctx.primitives.bool,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| { |generator, ctx, val| {
let elem = call_bool(generator, ctx, (elem_ty, val))?; let elem = call_bool(generator, ctx, (elem_ty, val))?;
@ -639,7 +640,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -690,7 +691,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
@ -921,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx let n_sz_eqz = ctx
@ -1135,7 +1136,7 @@ where
ctx, ctx,
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None),
|generator, ctx, elem_val| { |generator, ctx, elem_val| {
helper_call_numpy_unary_elementwise( helper_call_numpy_unary_elementwise(
generator, generator,
@ -1974,7 +1975,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2016,7 +2017,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unimplemented!("{FN_NAME} operates on float type NdArrays only");
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2066,7 +2067,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2121,7 +2122,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2163,7 +2164,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2206,7 +2207,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2259,7 +2260,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call // Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape( let n2_array = numpy::create_ndarray_const_shape(
generator, generator,
@ -2354,7 +2355,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()
@ -2397,7 +2398,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
let dim0 = unsafe { let dim0 = unsafe {
n1.shape() n1.shape()

View File

@ -1570,12 +1570,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
left_val.into_pointer_value(), left_val.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
None,
llvm_usize, llvm_usize,
None, None,
); );
let right_val = NDArrayValue::from_pointer_value( let right_val = NDArrayValue::from_pointer_value(
right_val.into_pointer_value(), right_val.into_pointer_value(),
llvm_ndarray_dtype2, llvm_ndarray_dtype2,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1631,6 +1633,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let ndarray_val = NDArrayValue::from_pointer_value( let ndarray_val = NDArrayValue::from_pointer_value(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1828,6 +1831,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
let val = NDArrayValue::from_pointer_value( let val = NDArrayValue::from_pointer_value(
val.into_pointer_value(), val.into_pointer_value(),
llvm_ndarray_dtype, llvm_ndarray_dtype,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1926,6 +1930,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let left_val = NDArrayValue::from_pointer_value( let left_val = NDArrayValue::from_pointer_value(
lhs.into_pointer_value(), lhs.into_pointer_value(),
llvm_ndarray_dtype1, llvm_ndarray_dtype1,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -2799,6 +2804,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray = NDArrayValue::from_pointer_value( let ndarray = NDArrayValue::from_pointer_value(
subscripted_ndarray, subscripted_ndarray,
llvm_ndarray_data_t, llvm_ndarray_data_t,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -3542,7 +3548,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
} }

View File

@ -0,0 +1,134 @@
use crate::codegen::{CodeGenContext, CodeGenerator};
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
///
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
#[must_use]
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
// pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndims: Instance<'ctx, Int<SizeT>>,
// shape: Instance<'ctx, Ptr<Int<SizeT>>>,
// ) {
// let name = get_usize_dependent_function_name(
// generator,
// ctx,
// "__nac3_ndarray_util_assert_shape_no_negative",
// );
// FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
// }
//
// pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray_ndims: Instance<'ctx, Int<SizeT>>,
// ndarray_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
// output_ndims: Instance<'ctx, Int<SizeT>>,
// output_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
// ) {
// let name = get_usize_dependent_function_name(
// generator,
// ctx,
// "__nac3_ndarray_util_assert_output_shape_same",
// );
// FnCall::builder(generator, ctx, &name)
// .arg(ndarray_ndims)
// .arg(ndarray_shape)
// .arg(output_ndims)
// .arg(output_shape)
// .returning_void();
// }
//
// pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) -> Instance<'ctx, Int<SizeT>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
// }
//
// pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) -> Instance<'ctx, Int<SizeT>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes")
// }
//
// pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) -> Instance<'ctx, Int<SizeT>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len")
// }
//
// pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) -> Instance<'ctx, Int<Bool>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous")
// }
//
// pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// index: Instance<'ctx, Int<SizeT>>,
// ) -> Instance<'ctx, Ptr<Int<Byte>>> {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement")
// }
//
// pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// indices: Instance<'ctx, Ptr<Int<SizeT>>>,
// ) -> Instance<'ctx, Ptr<Int<Byte>>> {
// let name =
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement")
// }
//
// pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) {
// let name =
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void();
// }
//
// pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
// ) {
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
// FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
// }

View File

@ -15,6 +15,9 @@ use crate::codegen::{
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
pub use basic::*;
mod basic;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size. /// calculated total size.

View File

@ -3,6 +3,7 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools;
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
@ -27,7 +28,7 @@ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, PrimDef}, helper::{arraylike_flatten_element_type, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::unpack_ndarray_var_tys,
DefinitionId, DefinitionId,
}, },
typecheck::{ typecheck::{
@ -43,19 +44,16 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type, elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
.get_llvm_type(generator, ndarray_ty) .as_base_type()
.into_pointer_type()
.get_element_type() .get_element_type()
.into_struct_type(); .into_struct_type();
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None)) Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None))
} }
/// Creates an `NDArray` instance from a dynamic shape. /// Creates an `NDArray` instance from a dynamic shape.
@ -189,28 +187,10 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
// TODO: Disallow dim_sz > u32_MAX // TODO: Disallow dim_sz > u32_MAX
} }
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.shape().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, true),
None,
)
};
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype)
.construct_dyn_shape(generator, ctx, shape, None);
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray) Ok(ndarray)
@ -338,20 +318,24 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
// Get the length/size of the tuple, which also happens to be the value of `ndims`. // Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields(); let ndims = shape_tuple.get_type().count_fields();
let mut shape = Vec::with_capacity(ndims as usize); let shape = (0..ndims)
for dim_i in 0..ndims { .map(|dim_i| {
let dim = ctx ctx.builder
.builder .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) .map(BasicValueEnum::into_int_value)
.unwrap() .map(|v| {
.into_int_value(); ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
})
.unwrap()
})
.collect_vec();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
} }
BasicValueEnum::IntValue(shape_int) => { BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let shape_int =
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
@ -505,6 +489,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -517,6 +502,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -532,6 +518,7 @@ where
let lhs = NDArrayValue::from_pointer_value( let lhs = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -548,6 +535,7 @@ where
let rhs = NDArrayValue::from_pointer_value( let rhs = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -706,7 +694,8 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
{ {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
.load_ndims(ctx)
} }
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
@ -856,7 +845,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_representable(object, llvm_usize).is_ok() { if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
let ndarray = gen_if_else_expr_callback( let ndarray = gen_if_else_expr_callback(
generator, generator,
@ -932,6 +921,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
return Ok(NDArrayValue::from_pointer_value( return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
)); ));
@ -1465,6 +1455,7 @@ where
let lhs_val = NDArrayValue::from_pointer_value( let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(), lhs_val.into_pointer_value(),
llvm_lhs_elem_ty, llvm_lhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1473,6 +1464,7 @@ where
let rhs_val = NDArrayValue::from_pointer_value( let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(), rhs_val.into_pointer_value(),
llvm_rhs_elem_ty, llvm_rhs_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -1499,6 +1491,7 @@ where
let ndarray = NDArrayValue::from_pointer_value( let ndarray = NDArrayValue::from_pointer_value(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
); );
@ -2061,6 +2054,7 @@ pub fn gen_ndarray_copy<'ctx>(
NDArrayValue::from_pointer_value( NDArrayValue::from_pointer_value(
this_arg.into_pointer_value(), this_arg.into_pointer_value(),
llvm_elem_ty, llvm_elem_ty,
None,
llvm_usize, llvm_usize,
None, None,
), ),
@ -2098,7 +2092,7 @@ pub fn gen_ndarray_fill<'ctx>(
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
context, context,
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|generator, ctx, _| { |generator, ctx, _| {
let value = if value_arg.is_pointer_value() { let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
@ -2140,7 +2134,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
@ -2260,7 +2254,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2548,8 +2542,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));

View File

@ -1,6 +1,5 @@
use inkwell::context::ContextRef;
use inkwell::{ use inkwell::{
context::{AsContextRef, Context}, context::{AsContextRef, Context, ContextRef},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue}, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
@ -11,9 +10,13 @@ use super::{
structure::{FieldIndexCounter, StructField, StructFields}, structure::{FieldIndexCounter, StructField, StructFields},
ProxyType, ProxyType,
}; };
use crate::codegen::{ use crate::{
values::{ArraySliceValue, NDArrayValue, ProxyValue}, codegen::{
{CodeGenContext, CodeGenerator}, values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
{CodeGenContext, CodeGenerator},
},
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type,
}; };
/// Proxy type for a `ndarray` type in LLVM. /// Proxy type for a `ndarray` type in LLVM.
@ -21,13 +24,17 @@ use crate::codegen::{
pub struct NDArrayType<'ctx> { pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>, ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
// TODO(Derppening): Make this non-optional
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
} }
#[derive(PartialEq, Eq, Clone, Copy)] #[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> { pub struct NDArrayStructFields<'ctx> {
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>, pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>, pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>,
pub data: StructField<'ctx, PointerValue<'ctx>>, pub data: StructField<'ctx, PointerValue<'ctx>>,
} }
@ -37,12 +44,18 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
let mut counter = FieldIndexCounter::default(); let mut counter = FieldIndexCounter::default();
NDArrayStructFields { NDArrayStructFields {
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize), ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create( shape: StructField::create(
&mut counter, &mut counter,
"shape", "shape",
llvm_usize.ptr_type(AddressSpace::default()), llvm_usize.ptr_type(AddressSpace::default()),
), ),
strides: StructField::create(
&mut counter,
"strides",
llvm_usize.ptr_type(AddressSpace::default()),
),
data: StructField::create( data: StructField::create(
&mut counter, &mut counter,
"data", "data",
@ -52,7 +65,13 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
} }
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![self.ndims.into(), self.shape.into(), self.data.into()] vec![
self.itemsize.into(),
self.ndims.into(),
self.shape.into(),
self.strides.into(),
self.data.into(),
]
} }
} }
@ -62,70 +81,45 @@ impl<'ctx> NDArrayType<'ctx> {
llvm_ty: PointerType<'ctx>, llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
let llvm_ndarray_ty = llvm_ty.get_element_type(); let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
}; };
if llvm_ndarray_ty.count_fields() != 3 { if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
return Err(format!( return Err(format!(
"Expected 3 fields in `NDArray`, got {}", "Expected {} fields in `NDArray`, got {}",
llvm_expected_ty.len(),
llvm_ndarray_ty.count_fields() llvm_ndarray_ty.count_fields()
)); ));
} }
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); llvm_expected_ty
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { .iter()
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); .enumerate()
}; .map(|(i, expected_ty)| {
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { (expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
return Err(format!( })
"Expected {}-bit int type for `ndarray.0`, got {}-bit int", .try_for_each(|(expected_ty, actual_ty)| {
llvm_usize.get_bit_width(), if expected_ty == actual_ty {
ndarray_ndims_ty.get_bit_width() Ok(())
)); } else {
} Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); })?;
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
let ndarray_data = ndarray_pdata.get_element_type();
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
));
};
if ndarray_data.get_bit_width() != 8 {
return Err(format!(
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
ndarray_data.get_bit_width()
));
}
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType // TODO: Move this into e.g. StructProxyType
#[must_use] #[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { fn fields(
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize) NDArrayStructFields::new(ctx, llvm_usize)
} }
@ -133,7 +127,7 @@ impl<'ctx> NDArrayType<'ctx> {
#[must_use] #[must_use]
pub fn get_fields( pub fn get_fields(
&self, &self,
ctx: &'ctx Context, ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> { ) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize) Self::fields(ctx, llvm_usize)
@ -142,7 +136,7 @@ impl<'ctx> NDArrayType<'ctx> {
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use] #[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
// //
// * data : Pointer to an array containing the array data // * data : Pointer to an array containing the array data
// * itemsize: The size of each NDArray elements in bytes // * itemsize: The size of each NDArray elements in bytes
@ -165,7 +159,28 @@ impl<'ctx> NDArrayType<'ctx> {
let llvm_usize = generator.get_size_type(ctx); let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize } NDArrayType { ty: llvm_ndarray, dtype, ndims: None, llvm_usize }
}
/// Creates an [`NDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
NDArrayType {
ty: Self::llvm_type(ctx.ctx, llvm_usize),
dtype: llvm_dtype,
ndims: Some(ndims),
llvm_usize,
}
} }
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
@ -177,7 +192,7 @@ impl<'ctx> NDArrayType<'ctx> {
) -> Self { ) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, dtype, llvm_usize } NDArrayType { ty: ptr_ty, dtype, ndims: None, llvm_usize }
} }
/// Returns the type of the `size` field of this `ndarray` type. /// Returns the type of the `size` field of this `ndarray` type.
@ -186,7 +201,7 @@ impl<'ctx> NDArrayType<'ctx> {
self.as_base_type() self.as_base_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type()
.get_field_type_at_index(0) .get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type) .map(BasicTypeEnum::into_int_type)
.unwrap() .unwrap()
} }
@ -196,6 +211,114 @@ impl<'ctx> NDArrayType<'ctx> {
pub fn element_type(&self) -> BasicTypeEnum<'ctx> { pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype self.dtype
} }
/// Returns the number of dimensions represented by this [`NDArrayType`], or [`None`] if it is
/// not known.
#[must_use]
pub fn ndims_as_value(&self) -> Option<IntValue<'ctx>> {
self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false))
}
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated onto the stack.
///
/// The returned ndarray's content will be:
/// - `data`: uninitialized.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
#[must_use]
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Option<u64>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.new_value(generator, ctx, name);
let itemsize = ctx
.builder
.build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap();
ndarray.store_itemsize(ctx, generator, itemsize);
let ndims_val = self.llvm_usize.const_int(ndims.or(self.ndims).unwrap(), false);
ndarray.store_ndims(ctx, generator, ndims_val);
ndarray.create_shape(ctx, self.llvm_usize, ndims_val);
ndarray.create_strides(ctx, self.llvm_usize, ndims_val);
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[u64],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
let dim = self.llvm_usize.const_int(*dim, false);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&self.llvm_usize.const_int(i as u64, false),
dim,
);
}
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[IntValue<'ctx>],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
assert_eq!(
dim.get_type(),
self.llvm_usize,
"Expected {} but got {}",
self.llvm_usize.print_to_string(),
dim.get_type().print_to_string()
);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&self.llvm_usize.const_int(i as u64, false),
*dim,
);
}
}
ndarray
}
} }
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
@ -264,7 +387,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
) -> Self::Value { ) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type()); debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name) NDArrayValue::from_pointer_value(value, self.dtype, self.ndims, self.llvm_usize, name)
} }
fn as_base_type(&self) -> Self::Base { fn as_base_type(&self) -> Self::Base {

View File

@ -86,7 +86,7 @@ where
/// index. /// index.
/// * `name` - Name of the field. /// * `name` - Name of the field.
/// * `ty` - The type of this field. /// * `ty` - The type of this field.
pub fn create( pub(super) fn create(
idx_counter: &mut FieldIndexCounter, idx_counter: &mut FieldIndexCounter,
name: &'static str, name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>, ty: impl Into<BasicTypeEnum<'ctx>>,
@ -99,7 +99,11 @@ where
/// * `index` - The index of this field within its enclosing structure. /// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field. /// * `name` - Name of the field.
/// * `ty` - The type of this field. /// * `ty` - The type of this field.
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self { pub(super) fn create_at(
index: u32,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData } StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
} }
@ -189,7 +193,7 @@ where
/// A counter that tracks the next index of a field using a monotonically increasing counter. /// A counter that tracks the next index of a field using a monotonically increasing counter.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)] #[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub struct FieldIndexCounter(u32); pub(super) struct FieldIndexCounter(u32);
impl FieldIndexCounter { impl FieldIndexCounter {
/// Increments the number stored by this counter, returning the previous value. /// Increments the number stored by this counter, returning the previous value.

View File

@ -21,6 +21,7 @@ use crate::codegen::{
pub struct NDArrayValue<'ctx> { pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>, value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
} }
@ -40,12 +41,13 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn from_pointer_value( pub fn from_pointer_value(
ptr: PointerValue<'ctx>, ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Self { ) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, llvm_usize, name } NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
@ -75,6 +77,33 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
} }
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.itemsize
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling /// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field. /// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
@ -105,6 +134,36 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayShapeProxy(self) NDArrayShapeProxy(self)
} }
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.strides
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
@ -168,103 +227,6 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
} }
} }
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. /// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
@ -515,3 +477,197 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
for NDArrayDataProxy<'ctx, '_> for NDArrayDataProxy<'ctx, '_>
{ {
} }
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}

View File

@ -1759,14 +1759,14 @@ def run() -> int32:
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_cholesky() # test_ndarray_cholesky()
test_ndarray_qr() # test_ndarray_qr()
test_ndarray_svd() # test_ndarray_svd()
test_ndarray_linalg_inv() # test_ndarray_linalg_inv()
test_ndarray_pinv() # test_ndarray_pinv()
test_ndarray_matrix_power() # test_ndarray_matrix_power()
test_ndarray_det() # test_ndarray_det()
test_ndarray_lu() # test_ndarray_lu()
test_ndarray_schur() # test_ndarray_schur()
test_ndarray_hessenberg() # test_ndarray_hessenberg()
return 0 return 0