Compare commits
16 Commits
c862dbd861
...
41fb27c913
Author | SHA1 | Date |
---|---|---|
David Mak | 41fb27c913 | |
David Mak | b3b61959dd | |
David Mak | 93c4cf74cd | |
David Mak | 2ffdedb6ec | |
David Mak | e2a775c7b0 | |
David Mak | 4ce1a3f4e8 | |
lyken | 923690beba | |
lyken | 4093a334cf | |
David Mak | ffce2fc02a | |
David Mak | f0bd8ea38c | |
David Mak | 84c165d008 | |
David Mak | c2f29e42d6 | |
David Mak | 3111afae8f | |
David Mak | e2bbedbec9 | |
David Mak | 7e43a37e7d | |
David Mak | d61f585f5a |
|
@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.37"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40545c26d092346d8a8dab71ee48e7685a7a9cba76e634790c215b41a4a7b4cf"
|
||||
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
@ -141,9 +141,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
|||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.20"
|
||||
version = "4.5.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
|
||||
checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
|
@ -151,9 +151,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.20"
|
||||
version = "4.5.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
|
||||
checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
|
@ -175,9 +175,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.2"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
|
||||
checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
|
@ -199,9 +199,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.14"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
|
||||
checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
@ -282,6 +282,12 @@ dependencies = [
|
|||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dissimilar"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59f8e79d1fbf76bdfbde321e902714bf6c49df88a7dda6fc682fc2979226962d"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
|
@ -370,6 +376,12 @@ dependencies = [
|
|||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
|
@ -547,9 +559,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
|||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.162"
|
||||
version = "0.2.164"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398"
|
||||
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
|
@ -648,6 +660,7 @@ dependencies = [
|
|||
"inkwell",
|
||||
"insta",
|
||||
"itertools",
|
||||
"nac3core_derive",
|
||||
"nac3parser",
|
||||
"parking_lot",
|
||||
"rayon",
|
||||
|
@ -657,6 +670,18 @@ dependencies = [
|
|||
"test-case",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nac3core_derive"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"nac3core",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
"trybuild",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nac3ld"
|
||||
version = "0.1.0"
|
||||
|
@ -822,6 +847,30 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.89"
|
||||
|
@ -976,9 +1025,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.8"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
|
||||
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
|
@ -1000,9 +1049,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.40"
|
||||
version = "0.38.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0"
|
||||
checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
|
@ -1046,18 +1095,18 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.214"
|
||||
version = "1.0.215"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5"
|
||||
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.214"
|
||||
version = "1.0.215"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766"
|
||||
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1066,9 +1115,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.132"
|
||||
version = "1.0.133"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
|
||||
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
|
@ -1076,6 +1125,15 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.8.26"
|
||||
|
@ -1199,6 +1257,12 @@ version = "0.12.16"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||
|
||||
[[package]]
|
||||
name = "target-triple"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078"
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.14.0"
|
||||
|
@ -1222,6 +1286,15 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "test-case"
|
||||
version = "1.2.3"
|
||||
|
@ -1255,6 +1328,56 @@ dependencies = [
|
|||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.8.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_edit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.22.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
|
||||
dependencies = [
|
||||
"indexmap 2.6.0",
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trybuild"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4"
|
||||
dependencies = [
|
||||
"dissimilar",
|
||||
"glob",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"target-triple",
|
||||
"termcolor",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
@ -1478,6 +1601,15 @@ version = "0.52.6"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.6.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yaml-rust"
|
||||
version = "0.4.5"
|
||||
|
|
|
@ -4,6 +4,7 @@ members = [
|
|||
"nac3ast",
|
||||
"nac3parser",
|
||||
"nac3core",
|
||||
"nac3core/nac3core_derive",
|
||||
"nac3standalone",
|
||||
"nac3artiq",
|
||||
"runkernel",
|
||||
|
|
|
@ -1370,7 +1370,6 @@ fn polymorphic_print<'ctx>(
|
|||
let val = NDArrayValue::from_pointer_value(
|
||||
value.into_pointer_value(),
|
||||
llvm_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
|
|
@ -5,6 +5,8 @@ authors = ["M-Labs"]
|
|||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["derive"]
|
||||
derive = ["dep:nac3core_derive"]
|
||||
no-escape-analysis = []
|
||||
tracing = []
|
||||
|
||||
|
@ -14,6 +16,7 @@ crossbeam = "0.8"
|
|||
indexmap = "2.6"
|
||||
parking_lot = "0.12"
|
||||
rayon = "1.10"
|
||||
nac3core_derive = { path = "nac3core_derive", optional = true }
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
|
|
|
@ -3,5 +3,3 @@
|
|||
#include "irrt/math.hpp"
|
||||
#include "irrt/ndarray.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
// TODO: To be deleted since NDArray with strides is done.
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||
|
|
|
@ -1,342 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace basic {
|
||||
/**
|
||||
* @brief Assert that `shape` does not contain negative dimensions.
|
||||
*
|
||||
* @param ndims Number of dimensions in `shape`
|
||||
* @param shape The shape to check on
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
if (shape[axis] < 0) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"negative dimensions are not allowed; axis {0} "
|
||||
"has dimension {1}",
|
||||
axis, shape[axis], NO_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void assert_output_shape_same(SizeT ndarray_ndims,
|
||||
const SizeT* ndarray_shape,
|
||||
SizeT output_ndims,
|
||||
const SizeT* output_shape) {
|
||||
if (ndarray_ndims != output_ndims) {
|
||||
// There is no corresponding NumPy error message like this.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
||||
output_ndims, ndarray_ndims, NO_PARAM);
|
||||
}
|
||||
|
||||
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
||||
if (ndarray_shape[axis] != output_shape[axis]) {
|
||||
// There is no corresponding NumPy error message like this.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"Mismatched dimensions on axis {0}, output has "
|
||||
"dimension {1}, but destination ndarray has dimension {2}.",
|
||||
axis, output_shape[axis], ndarray_shape[axis]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the number of elements of an ndarray given its shape.
|
||||
*
|
||||
* @param ndims Number of dimensions in `shape`
|
||||
* @param shape The shape of the ndarray
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||
SizeT size = 1;
|
||||
for (SizeT axis = 0; axis < ndims; axis++)
|
||||
size *= shape[axis];
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
||||
*
|
||||
* @param ndims Number of elements in `shape` and `indices`
|
||||
* @param shape The shape of the ndarray
|
||||
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
||||
* @param nth The index of the element of interest.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
SizeT axis = ndims - i - 1;
|
||||
SizeT dim = shape[axis];
|
||||
|
||||
indices[axis] = nth % dim;
|
||||
nth /= dim;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the number of elements of an `ndarray`
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.size`
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT size(const NDArray<SizeT>* ndarray) {
|
||||
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return of the number of its content of an `ndarray`.
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.nbytes`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
||||
return size(ndarray) * ndarray->itemsize;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.__len__`.
|
||||
*
|
||||
* @param dst_length The length.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT len(const NDArray<SizeT>* ndarray) {
|
||||
if (ndarray->ndims != 0) {
|
||||
return ndarray->shape[0];
|
||||
}
|
||||
|
||||
// numpy prohibits `__len__` on unsized objects
|
||||
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
||||
*
|
||||
* You may want to see ndarray's rules for C-contiguity:
|
||||
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||
*/
|
||||
template<typename SizeT>
|
||||
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
||||
// References:
|
||||
// - tinynumpy's implementation:
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
||||
// - ndarray's flags["C_CONTIGUOUS"]:
|
||||
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
||||
// - ndarray's rules for C-contiguity:
|
||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||
|
||||
// From
|
||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
||||
//
|
||||
// The traditional rule is that for an array to be flagged as C contiguous,
|
||||
// the following must hold:
|
||||
//
|
||||
// strides[-1] == itemsize
|
||||
// strides[i] == shape[i+1] * strides[i + 1]
|
||||
// [...]
|
||||
// According to these rules, a 0- or 1-dimensional array is either both
|
||||
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
||||
// can be C- or F- contiguous, or neither, but not both. Though there
|
||||
// there are exceptions for arrays with zero or one item, in the first
|
||||
// case the check is relaxed up to and including the first dimension
|
||||
// with shape[i] == 0. In the second case `strides == itemsize` will
|
||||
// can be true for all dimensions and both flags are set.
|
||||
|
||||
if (ndarray->ndims == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
||||
SizeT axis_i = ndarray->ndims - i - 1;
|
||||
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
||||
*
|
||||
* This function does no bound check.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
||||
void* element = ndarray->data;
|
||||
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
||||
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
||||
return element;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
||||
*
|
||||
* This function does no bound check.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
||||
void* element = ndarray->data;
|
||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||
SizeT axis = ndarray->ndims - i - 1;
|
||||
SizeT dim = ndarray->shape[axis];
|
||||
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
||||
nth /= dim;
|
||||
}
|
||||
return element;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
||||
*
|
||||
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||
SizeT stride_product = 1;
|
||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||
SizeT axis = ndarray->ndims - i - 1;
|
||||
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
||||
stride_product *= ndarray->shape[axis];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set an element in `ndarray`.
|
||||
*
|
||||
* @param pelement Pointer to the element in `ndarray` to be set.
|
||||
* @param pvalue Pointer to the value `pelement` will be set to.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
||||
*
|
||||
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
||||
// TODO: Handle overlapping.
|
||||
|
||||
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
||||
|
||||
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||
}
|
||||
}
|
||||
} // namespace basic
|
||||
} // namespace ndarray
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::basic;
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
||||
assert_shape_no_negative(ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
||||
assert_shape_no_negative(ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
||||
const int32_t* ndarray_shape,
|
||||
int32_t output_ndims,
|
||||
const int32_t* output_shape) {
|
||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
||||
const int64_t* ndarray_shape,
|
||||
int64_t output_ndims,
|
||||
const int64_t* output_shape) {
|
||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
||||
return len(ndarray);
|
||||
}
|
||||
|
||||
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
||||
return len(ndarray);
|
||||
}
|
||||
|
||||
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
||||
return is_c_contiguous(ndarray);
|
||||
}
|
||||
|
||||
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
||||
return is_c_contiguous(ndarray);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
||||
return get_nth_pelement(ndarray, nth);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
||||
return get_nth_pelement(ndarray, nth);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
||||
return get_pelement_by_indices(ndarray, indices);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||
return get_pelement_by_indices(ndarray, indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief The NDArray object
|
||||
*
|
||||
* Official numpy implementation:
|
||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
||||
*
|
||||
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
||||
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
||||
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
||||
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
||||
* `data`. There are also minor differences in the struct layout.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
struct NDArray {
|
||||
/**
|
||||
* @brief The number of bytes of a single element in `data`.
|
||||
*/
|
||||
SizeT itemsize;
|
||||
|
||||
/**
|
||||
* @brief The number of dimensions of this shape.
|
||||
*/
|
||||
SizeT ndims;
|
||||
|
||||
/**
|
||||
* @brief The NDArray shape, with length equal to `ndims`.
|
||||
*
|
||||
* Note that it may contain 0.
|
||||
*/
|
||||
SizeT* shape;
|
||||
|
||||
/**
|
||||
* @brief Array strides, with length equal to `ndims`
|
||||
*
|
||||
* The stride values are in units of bytes, not number of elements.
|
||||
*
|
||||
* Note that `strides` can have negative values or contain 0.
|
||||
*/
|
||||
SizeT* strides;
|
||||
|
||||
/**
|
||||
* @brief The underlying data this `ndarray` is pointing to.
|
||||
*/
|
||||
void* data;
|
||||
};
|
||||
} // namespace
|
|
@ -0,0 +1,21 @@
|
|||
[package]
|
||||
name = "nac3core_derive"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[[test]]
|
||||
name = "structfields_tests"
|
||||
path = "tests/structfields_test.rs"
|
||||
|
||||
[dev-dependencies]
|
||||
nac3core = { path = ".." }
|
||||
trybuild = { version = "1.0", features = ["diff"] }
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0"
|
||||
proc-macro-error = "1.0"
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
|
@ -0,0 +1,231 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro_error::{abort, proc_macro_error};
|
||||
use quote::quote;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{
|
||||
parse_macro_input, Data, DataStruct, Expr, ExprPath, GenericArgument, Ident, LitStr,
|
||||
PathArguments, Type, TypePath,
|
||||
};
|
||||
|
||||
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
|
||||
///
|
||||
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
|
||||
/// `expected_ty_name`, otherwise returns [`None`].
|
||||
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
|
||||
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let segments = &path.segments;
|
||||
if segments.len() != 1 {
|
||||
return None;
|
||||
};
|
||||
|
||||
let segment = segments.iter().next().unwrap();
|
||||
if segment.ident != expected_ty_name {
|
||||
return None;
|
||||
}
|
||||
|
||||
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
|
||||
return Some(Vec::new());
|
||||
};
|
||||
let args = &path_args.args;
|
||||
|
||||
Some(args.iter().cloned().collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
/// Normalizes a value expression for use when creating an instance of this structure, returning a
|
||||
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
|
||||
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
|
||||
match &expr {
|
||||
Expr::Path(ExprPath { qself: None, path, .. }) => {
|
||||
let Ok(ident) = path.require_ident() else {
|
||||
abort!(
|
||||
path,
|
||||
format!(
|
||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||
quote!(#expr).to_string(),
|
||||
)
|
||||
)
|
||||
};
|
||||
|
||||
if ident == "usize" || ident == "size_t" {
|
||||
let llvm_usize = Ident::new("llvm_usize", ident.span());
|
||||
quote! { #llvm_usize }
|
||||
} else {
|
||||
abort!(
|
||||
path,
|
||||
format!(
|
||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||
quote!(#expr).to_string(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Expr::Call(_) | Expr::MethodCall(_) => {
|
||||
quote! { ctx.#expr }
|
||||
}
|
||||
|
||||
_ => {
|
||||
abort!(
|
||||
expr,
|
||||
format!(
|
||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||
quote!(#expr).to_string(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Derives an implementation of `codegen::types::structure::StructFields`.
|
||||
///
|
||||
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
|
||||
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
|
||||
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
|
||||
///
|
||||
/// # Prerequisites
|
||||
///
|
||||
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
|
||||
/// `StructFields`.
|
||||
///
|
||||
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
|
||||
/// with either `StructField` or [`PhantomData`] types.
|
||||
///
|
||||
/// # Attributes for [`StructFields`]
|
||||
///
|
||||
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
|
||||
/// accepts either an expression returning an instance of `inkwell::types::BasicType` without the
|
||||
/// `inkwell::context::Context` instance prefix, or the reserved identifiers `usize` and `size_t` referring to an
|
||||
/// `inkwell::types::IntType` of the platform-dependent integer size.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
|
||||
///
|
||||
/// ```
|
||||
/// use nac3core::{
|
||||
/// codegen::types::structure::StructField,
|
||||
/// inkwell::{
|
||||
/// values::{IntValue, PointerValue},
|
||||
/// AddressSpace,
|
||||
/// },
|
||||
/// };
|
||||
/// use nac3core_derive::StructFields;
|
||||
///
|
||||
/// // All classes that implement StructFields must also implement Eq and Copy
|
||||
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
/// pub struct SliceValue<'ctx> {
|
||||
/// // Declares ptr have a value type of i8*
|
||||
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
///
|
||||
/// // Declares len have a value type of usize, depending on the target compilation platform
|
||||
/// #[value_type(usize)]
|
||||
/// len: StructField<'ctx, IntValue<'ctx>>,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(StructFields, attributes(value_type))]
|
||||
#[proc_macro_error]
|
||||
pub fn derive(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as syn::DeriveInput);
|
||||
let ident = &input.ident;
|
||||
|
||||
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
|
||||
abort!(input, "Only structs with named fields are supported");
|
||||
};
|
||||
if let Err(err_span) =
|
||||
fields
|
||||
.iter()
|
||||
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
|
||||
{
|
||||
abort!(err_span, "Only structs with named fields are supported");
|
||||
};
|
||||
|
||||
// Check if struct<'ctx>
|
||||
if input.generics.params.len() != 1 {
|
||||
abort!(input.generics, "Expected exactly 1 generic parameter")
|
||||
}
|
||||
|
||||
let phantom_info = fields
|
||||
.iter()
|
||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
|
||||
.map(|field| field.ident.as_ref().unwrap())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let field_info = fields
|
||||
.iter()
|
||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
|
||||
.map(|field| {
|
||||
let ident = field.ident.as_ref().unwrap();
|
||||
let ty = &field.ty;
|
||||
|
||||
let Some(_) = extract_generic_args("StructField", ty) else {
|
||||
abort!(field, "Only StructField and PhantomData are allowed")
|
||||
};
|
||||
|
||||
let attrs = &field.attrs;
|
||||
let Some(value_type_attr) =
|
||||
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
|
||||
else {
|
||||
abort!(field, "Expected #[value_type(...)] attribute for field");
|
||||
};
|
||||
|
||||
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
|
||||
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
|
||||
};
|
||||
|
||||
let value_expr_toks = normalize_value_expr(&value_type_expr);
|
||||
|
||||
(ident.clone(), value_expr_toks)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
|
||||
let phantoms_create = phantom_info
|
||||
.iter()
|
||||
.map(|id| quote! { #id: ::std::marker::PhantomData })
|
||||
.collect::<Vec<_>>();
|
||||
let fields_create = field_info
|
||||
.iter()
|
||||
.map(|(id, ty)| {
|
||||
let id_lit = LitStr::new(&id.to_string(), id.span());
|
||||
quote! {
|
||||
#id: ::nac3core::codegen::types::structure::StructField::create(
|
||||
&mut counter,
|
||||
#id_lit,
|
||||
#ty,
|
||||
)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// `.into()` impl of `StructField` for `StructFields::to_vec`
|
||||
let fields_into =
|
||||
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
|
||||
|
||||
let impl_block = quote! {
|
||||
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
|
||||
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
|
||||
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
|
||||
|
||||
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
|
||||
|
||||
#ident {
|
||||
#(#fields_create),*
|
||||
#(#phantoms_create),*
|
||||
}
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
|
||||
vec![
|
||||
#(#fields_into),*
|
||||
]
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
impl_block.into()
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
use nac3core_derive::StructFields;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct EmptyValue<'ctx> {
|
||||
_phantom: PhantomData<&'ctx ()>,
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,18 @@
|
|||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceValue<'ctx> {
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,6 @@
|
|||
#[test]
|
||||
fn test_parse_empty() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.pass("tests/structfields_empty.rs");
|
||||
t.pass("tests/structfields_slice.rs");
|
||||
}
|
|
@ -74,7 +74,6 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let arg = NDArrayValue::from_pointer_value(
|
||||
arg.into_pointer_value(),
|
||||
ctx.get_llvm_type(generator, elem_ty),
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -154,7 +153,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.int32,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -217,7 +216,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.int64,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -296,7 +295,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.uint32,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -364,7 +363,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.uint64,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -431,7 +430,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.float,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -478,7 +477,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ret_elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||
)?;
|
||||
|
||||
|
@ -519,7 +518,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.float,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
||||
)?;
|
||||
|
||||
|
@ -585,7 +584,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
||||
|
||||
|
@ -640,7 +639,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ret_elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||
)?;
|
||||
|
||||
|
@ -691,7 +690,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
ret_elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||
)?;
|
||||
|
||||
|
@ -922,7 +921,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None);
|
||||
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx
|
||||
|
@ -1136,7 +1135,7 @@ where
|
|||
ctx,
|
||||
ret_elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, elem_val| {
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
|
@ -1975,7 +1974,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
|
@ -2017,7 +2016,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
|
@ -2067,7 +2066,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
|
@ -2122,7 +2121,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
|
@ -2164,7 +2163,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
|
@ -2207,7 +2206,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
|
@ -2260,7 +2259,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let n2_array = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
|
@ -2355,7 +2354,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
|
@ -2398,7 +2397,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.shape()
|
||||
|
|
|
@ -1570,14 +1570,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let left_val = NDArrayValue::from_pointer_value(
|
||||
left_val.into_pointer_value(),
|
||||
llvm_ndarray_dtype1,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
let right_val = NDArrayValue::from_pointer_value(
|
||||
right_val.into_pointer_value(),
|
||||
llvm_ndarray_dtype2,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -1633,7 +1631,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let ndarray_val = NDArrayValue::from_pointer_value(
|
||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||
llvm_ndarray_dtype,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -1831,7 +1828,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let val = NDArrayValue::from_pointer_value(
|
||||
val.into_pointer_value(),
|
||||
llvm_ndarray_dtype,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -1930,7 +1926,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let left_val = NDArrayValue::from_pointer_value(
|
||||
lhs.into_pointer_value(),
|
||||
llvm_ndarray_dtype1,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -2804,7 +2799,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
let ndarray = NDArrayValue::from_pointer_value(
|
||||
subscripted_ndarray,
|
||||
llvm_ndarray_data_t,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -3548,7 +3542,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None);
|
||||
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None);
|
||||
|
||||
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||
}
|
||||
|
|
|
@ -15,9 +15,6 @@ use crate::codegen::{
|
|||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
pub use basic::*;
|
||||
|
||||
mod basic;
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||
/// calculated total size.
|
|
@ -1,134 +0,0 @@
|
|||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
|
||||
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
||||
///
|
||||
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
||||
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
||||
#[must_use]
|
||||
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
name: &str,
|
||||
) -> String {
|
||||
let mut name = name.to_owned();
|
||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => {}
|
||||
64 => name.push_str("64"),
|
||||
bit_width => {
|
||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
|
||||
// pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndims: Instance<'ctx, Int<SizeT>>,
|
||||
// shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
// ) {
|
||||
// let name = get_usize_dependent_function_name(
|
||||
// generator,
|
||||
// ctx,
|
||||
// "__nac3_ndarray_util_assert_shape_no_negative",
|
||||
// );
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray_ndims: Instance<'ctx, Int<SizeT>>,
|
||||
// ndarray_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
// output_ndims: Instance<'ctx, Int<SizeT>>,
|
||||
// output_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
// ) {
|
||||
// let name = get_usize_dependent_function_name(
|
||||
// generator,
|
||||
// ctx,
|
||||
// "__nac3_ndarray_util_assert_output_shape_same",
|
||||
// );
|
||||
// FnCall::builder(generator, ctx, &name)
|
||||
// .arg(ndarray_ndims)
|
||||
// .arg(ndarray_shape)
|
||||
// .arg(output_ndims)
|
||||
// .arg(output_shape)
|
||||
// .returning_void();
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) -> Instance<'ctx, Int<SizeT>> {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) -> Instance<'ctx, Int<SizeT>> {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) -> Instance<'ctx, Int<SizeT>> {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) -> Instance<'ctx, Int<Bool>> {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// index: Instance<'ctx, Int<SizeT>>,
|
||||
// ) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||
// ) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
||||
// let name =
|
||||
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement")
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) {
|
||||
// let name =
|
||||
// get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||
// FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void();
|
||||
// }
|
||||
//
|
||||
// pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// generator: &mut G,
|
||||
// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
// src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||
// ) {
|
||||
// let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||
// FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
||||
// }
|
|
@ -3,7 +3,6 @@ use inkwell::{
|
|||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3parser::ast::{Operator, StrRef};
|
||||
|
||||
|
@ -28,7 +27,7 @@ use crate::{
|
|||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
helper::{arraylike_flatten_element_type, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId,
|
||||
},
|
||||
typecheck::{
|
||||
|
@ -44,16 +43,19 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|||
elem_ty: Type,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
|
||||
.as_base_type()
|
||||
let llvm_ndarray_t = ctx
|
||||
.get_llvm_type(generator, ndarray_ty)
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
.into_struct_type();
|
||||
|
||||
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||
|
||||
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None))
|
||||
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
|
||||
}
|
||||
|
||||
/// Creates an `NDArray` instance from a dynamic shape.
|
||||
|
@ -187,10 +189,28 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
// TODO: Disallow dim_sz > u32_MAX
|
||||
}
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
|
||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||
|
||||
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
for (i, &shape_dim) in shape.iter().enumerate() {
|
||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||
let ndarray_dim = unsafe {
|
||||
ndarray.shape().ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, true),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
|
||||
}
|
||||
|
||||
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype)
|
||||
.construct_dyn_shape(generator, ctx, shape, None);
|
||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||
|
||||
Ok(ndarray)
|
||||
|
@ -318,24 +338,20 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||
let ndims = shape_tuple.get_type().count_fields();
|
||||
|
||||
let shape = (0..ndims)
|
||||
.map(|dim_i| {
|
||||
ctx.builder
|
||||
let mut shape = Vec::with_capacity(ndims as usize);
|
||||
for dim_i in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.map(|v| {
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
|
||||
})
|
||||
.unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
.into_int_value();
|
||||
|
||||
shape.push(dim);
|
||||
}
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||
}
|
||||
BasicValueEnum::IntValue(shape_int) => {
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
let shape_int =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
|
||||
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||
}
|
||||
|
@ -489,7 +505,6 @@ where
|
|||
let lhs_val = NDArrayValue::from_pointer_value(
|
||||
lhs_val.into_pointer_value(),
|
||||
llvm_lhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -502,7 +517,6 @@ where
|
|||
let rhs_val = NDArrayValue::from_pointer_value(
|
||||
rhs_val.into_pointer_value(),
|
||||
llvm_rhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -518,7 +532,6 @@ where
|
|||
let lhs = NDArrayValue::from_pointer_value(
|
||||
lhs_val.into_pointer_value(),
|
||||
llvm_lhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -535,7 +548,6 @@ where
|
|||
let rhs = NDArrayValue::from_pointer_value(
|
||||
rhs_val.into_pointer_value(),
|
||||
llvm_rhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -694,8 +706,7 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|||
{
|
||||
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
||||
NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
|
||||
.load_ndims(ctx)
|
||||
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
|
||||
}
|
||||
|
||||
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
||||
|
@ -845,7 +856,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
||||
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
|
||||
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
|
||||
|
||||
let ndarray = gen_if_else_expr_callback(
|
||||
generator,
|
||||
|
@ -921,7 +932,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
return Ok(NDArrayValue::from_pointer_value(
|
||||
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
||||
llvm_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
));
|
||||
|
@ -1455,7 +1465,6 @@ where
|
|||
let lhs_val = NDArrayValue::from_pointer_value(
|
||||
lhs_val.into_pointer_value(),
|
||||
llvm_lhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -1464,7 +1473,6 @@ where
|
|||
let rhs_val = NDArrayValue::from_pointer_value(
|
||||
rhs_val.into_pointer_value(),
|
||||
llvm_rhs_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -1491,7 +1499,6 @@ where
|
|||
let ndarray = NDArrayValue::from_pointer_value(
|
||||
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
||||
llvm_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
@ -2054,7 +2061,6 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||
NDArrayValue::from_pointer_value(
|
||||
this_arg.into_pointer_value(),
|
||||
llvm_elem_ty,
|
||||
None,
|
||||
llvm_usize,
|
||||
None,
|
||||
),
|
||||
|
@ -2092,7 +2098,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||
ndarray_fill_flattened(
|
||||
generator,
|
||||
context,
|
||||
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|
||||
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|
||||
|generator, ctx, _| {
|
||||
let value = if value_arg.is_pointer_value() {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
@ -2134,7 +2140,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
|
||||
// Dimensions are reversed in the transposed array
|
||||
|
@ -2254,7 +2260,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
|
||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
|
@ -2542,8 +2548,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
||||
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
||||
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
|
||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
|
||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
||||
|
||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use inkwell::context::ContextRef;
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context, ContextRef},
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
|
@ -10,13 +11,9 @@ use super::{
|
|||
structure::{FieldIndexCounter, StructField, StructFields},
|
||||
ProxyType,
|
||||
};
|
||||
use crate::{
|
||||
codegen::{
|
||||
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
||||
use crate::codegen::{
|
||||
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
||||
{CodeGenContext, CodeGenerator},
|
||||
},
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// Proxy type for a `ndarray` type in LLVM.
|
||||
|
@ -24,17 +21,13 @@ use crate::{
|
|||
pub struct NDArrayType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
// TODO(Derppening): Make this non-optional
|
||||
ndims: Option<u64>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDArrayStructFields<'ctx> {
|
||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
|
@ -44,18 +37,12 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
|||
let mut counter = FieldIndexCounter::default();
|
||||
|
||||
NDArrayStructFields {
|
||||
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
||||
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||
shape: StructField::create(
|
||||
&mut counter,
|
||||
"shape",
|
||||
llvm_usize.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
strides: StructField::create(
|
||||
&mut counter,
|
||||
"strides",
|
||||
llvm_usize.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
data: StructField::create(
|
||||
&mut counter,
|
||||
"data",
|
||||
|
@ -65,13 +52,7 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
|||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||
vec![
|
||||
self.itemsize.into(),
|
||||
self.ndims.into(),
|
||||
self.shape.into(),
|
||||
self.strides.into(),
|
||||
self.data.into(),
|
||||
]
|
||||
vec![self.ndims.into(), self.shape.into(), self.data.into()]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,45 +62,70 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
llvm_ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
let ctx = llvm_ty.get_context();
|
||||
|
||||
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
|
||||
|
||||
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||
};
|
||||
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
|
||||
if llvm_ndarray_ty.count_fields() != 3 {
|
||||
return Err(format!(
|
||||
"Expected {} fields in `NDArray`, got {}",
|
||||
llvm_expected_ty.len(),
|
||||
"Expected 3 fields in `NDArray`, got {}",
|
||||
llvm_ndarray_ty.count_fields()
|
||||
));
|
||||
}
|
||||
|
||||
llvm_expected_ty
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, expected_ty)| {
|
||||
(expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
|
||||
})
|
||||
.try_for_each(|(expected_ty, actual_ty)| {
|
||||
if expected_ty == actual_ty {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
|
||||
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
||||
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
||||
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
|
||||
};
|
||||
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
||||
return Err(format!(
|
||||
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
|
||||
llvm_usize.get_bit_width(),
|
||||
ndarray_ndims_ty.get_bit_width()
|
||||
));
|
||||
}
|
||||
|
||||
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
||||
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
|
||||
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
|
||||
};
|
||||
let ndarray_dims = ndarray_pdims.get_element_type();
|
||||
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
|
||||
return Err(format!(
|
||||
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
|
||||
));
|
||||
};
|
||||
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
|
||||
return Err(format!(
|
||||
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||
llvm_usize.get_bit_width(),
|
||||
ndarray_dims.get_bit_width()
|
||||
));
|
||||
}
|
||||
|
||||
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
||||
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
||||
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
||||
};
|
||||
let ndarray_data = ndarray_pdata.get_element_type();
|
||||
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
||||
return Err(format!(
|
||||
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
|
||||
));
|
||||
};
|
||||
if ndarray_data.get_bit_width() != 8 {
|
||||
return Err(format!(
|
||||
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||
ndarray_data.get_bit_width()
|
||||
));
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: Move this into e.g. StructProxyType
|
||||
#[must_use]
|
||||
fn fields(
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
||||
NDArrayStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
|
@ -127,7 +133,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
#[must_use]
|
||||
pub fn get_fields(
|
||||
&self,
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
ctx: &'ctx Context,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
Self::fields(ctx, llvm_usize)
|
||||
|
@ -136,7 +142,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
// struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
|
||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||
//
|
||||
// * data : Pointer to an array containing the array data
|
||||
// * itemsize: The size of each NDArray elements in bytes
|
||||
|
@ -159,28 +165,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
let llvm_usize = generator.get_size_type(ctx);
|
||||
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
NDArrayType { ty: llvm_ndarray, dtype, ndims: None, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
NDArrayType {
|
||||
ty: Self::llvm_type(ctx.ctx, llvm_usize),
|
||||
dtype: llvm_dtype,
|
||||
ndims: Some(ndims),
|
||||
llvm_usize,
|
||||
}
|
||||
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||
|
@ -192,7 +177,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
) -> Self {
|
||||
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
NDArrayType { ty: ptr_ty, dtype, ndims: None, llvm_usize }
|
||||
NDArrayType { ty: ptr_ty, dtype, llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the type of the `size` field of this `ndarray` type.
|
||||
|
@ -201,7 +186,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
self.as_base_type()
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.get_field_type_at_index(1)
|
||||
.get_field_type_at_index(0)
|
||||
.map(BasicTypeEnum::into_int_type)
|
||||
.unwrap()
|
||||
}
|
||||
|
@ -211,114 +196,6 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions represented by this [`NDArrayType`], or [`None`] if it is
|
||||
/// not known.
|
||||
#[must_use]
|
||||
pub fn ndims_as_value(&self) -> Option<IntValue<'ctx>> {
|
||||
self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false))
|
||||
}
|
||||
|
||||
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
||||
///
|
||||
/// `shape` and `strides` will be automatically allocated onto the stack.
|
||||
///
|
||||
/// The returned ndarray's content will be:
|
||||
/// - `data`: uninitialized.
|
||||
/// - `itemsize`: set to the `sizeof()` of `dtype`.
|
||||
/// - `ndims`: set to the value of `ndims`.
|
||||
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
||||
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
||||
#[must_use]
|
||||
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndims: Option<u64>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.new_value(generator, ctx, name);
|
||||
|
||||
let itemsize = ctx
|
||||
.builder
|
||||
.build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
|
||||
.unwrap();
|
||||
ndarray.store_itemsize(ctx, generator, itemsize);
|
||||
|
||||
let ndims_val = self.llvm_usize.const_int(ndims.or(self.ndims).unwrap(), false);
|
||||
ndarray.store_ndims(ctx, generator, ndims_val);
|
||||
|
||||
ndarray.create_shape(ctx, self.llvm_usize, ndims_val);
|
||||
ndarray.create_strides(ctx, self.llvm_usize, ndims_val);
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||
#[must_use]
|
||||
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &[u64],
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name);
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
let dim = self.llvm_usize.const_int(*dim, false);
|
||||
unsafe {
|
||||
ndarray_shape.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&self.llvm_usize.const_int(i as u64, false),
|
||||
dim,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||
#[must_use]
|
||||
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &[IntValue<'ctx>],
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_uninitialized(generator, ctx, Some(shape.len() as u64), name);
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
assert_eq!(
|
||||
dim.get_type(),
|
||||
self.llvm_usize,
|
||||
"Expected {} but got {}",
|
||||
self.llvm_usize.print_to_string(),
|
||||
dim.get_type().print_to_string()
|
||||
);
|
||||
unsafe {
|
||||
ndarray_shape.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&self.llvm_usize.const_int(i as u64, false),
|
||||
*dim,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||
|
@ -387,7 +264,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
|||
) -> Self::Value {
|
||||
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||
|
||||
NDArrayValue::from_pointer_value(value, self.dtype, self.ndims, self.llvm_usize, name)
|
||||
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
|
|
|
@ -86,7 +86,7 @@ where
|
|||
/// index.
|
||||
/// * `name` - Name of the field.
|
||||
/// * `ty` - The type of this field.
|
||||
pub(super) fn create(
|
||||
pub fn create(
|
||||
idx_counter: &mut FieldIndexCounter,
|
||||
name: &'static str,
|
||||
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||
|
@ -99,11 +99,7 @@ where
|
|||
/// * `index` - The index of this field within its enclosing structure.
|
||||
/// * `name` - Name of the field.
|
||||
/// * `ty` - The type of this field.
|
||||
pub(super) fn create_at(
|
||||
index: u32,
|
||||
name: &'static str,
|
||||
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||
) -> Self {
|
||||
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
|
||||
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
||||
}
|
||||
|
||||
|
@ -193,7 +189,7 @@ where
|
|||
|
||||
/// A counter that tracks the next index of a field using a monotonically increasing counter.
|
||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub(super) struct FieldIndexCounter(u32);
|
||||
pub struct FieldIndexCounter(u32);
|
||||
|
||||
impl FieldIndexCounter {
|
||||
/// Increments the number stored by this counter, returning the previous value.
|
||||
|
|
|
@ -21,7 +21,6 @@ use crate::codegen::{
|
|||
pub struct NDArrayValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: Option<u64>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
@ -41,13 +40,12 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: Option<u64>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||
|
||||
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
|
||||
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||
}
|
||||
|
||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||
|
@ -77,33 +75,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||
}
|
||||
|
||||
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
|
||||
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.itemsize
|
||||
.ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the size of each element `itemsize` into this instance.
|
||||
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
ndims: IntValue<'ctx>,
|
||||
) {
|
||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||
|
||||
let pndims = self.ptr_to_ndims(ctx);
|
||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||
}
|
||||
|
||||
/// Returns the size of each element of this `NDArray` as a value.
|
||||
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
let pndims = self.ptr_to_ndims(ctx);
|
||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||
/// `getelementptr` on the field.
|
||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
|
@ -134,36 +105,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
NDArrayShapeProxy(self)
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `stride` array, as if by calling
|
||||
/// `getelementptr` on the field.
|
||||
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.strides
|
||||
.ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the array of dimension sizes `dims` into this instance.
|
||||
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing the stride with the given `size`.
|
||||
pub fn create_strides(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
) {
|
||||
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||
}
|
||||
|
||||
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
|
||||
#[must_use]
|
||||
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
|
||||
NDArrayStridesProxy(self)
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
|
@ -227,6 +168,103 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||
|
||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||
}
|
||||
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder
|
||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
self.0.load_ndims(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||
|
||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn downcast_to_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
value.into_int_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn upcast_from_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: IntValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
value.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||
|
@ -477,197 +515,3 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
|||
for NDArrayDataProxy<'ctx, '_>
|
||||
{
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||
|
||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||
}
|
||||
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder
|
||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
self.0.load_ndims(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||
|
||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn downcast_to_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
value.into_int_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||
fn upcast_from_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: IntValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
value.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||
|
||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||
}
|
||||
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder
|
||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
self.0.load_ndims(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
||||
|
||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||
fn downcast_to_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
value.into_int_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||
fn upcast_from_type(
|
||||
&self,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
value: IntValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
value.into()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1759,14 +1759,14 @@ def run() -> int32:
|
|||
test_ndarray_reshape()
|
||||
|
||||
test_ndarray_dot()
|
||||
# test_ndarray_cholesky()
|
||||
# test_ndarray_qr()
|
||||
# test_ndarray_svd()
|
||||
# test_ndarray_linalg_inv()
|
||||
# test_ndarray_pinv()
|
||||
# test_ndarray_matrix_power()
|
||||
# test_ndarray_det()
|
||||
# test_ndarray_lu()
|
||||
# test_ndarray_schur()
|
||||
# test_ndarray_hessenberg()
|
||||
test_ndarray_cholesky()
|
||||
test_ndarray_qr()
|
||||
test_ndarray_svd()
|
||||
test_ndarray_linalg_inv()
|
||||
test_ndarray_pinv()
|
||||
test_ndarray_matrix_power()
|
||||
test_ndarray_det()
|
||||
test_ndarray_lu()
|
||||
test_ndarray_schur()
|
||||
test_ndarray_hessenberg()
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue