forked from M-Labs/nac3
Compare commits
28 Commits
ndarray-st
...
ndarray-st
Author | SHA1 | Date |
---|---|---|
lyken | 059b130aff | |
lyken | 7742fbf9e0 | |
lyken | 71e05c17b9 | |
lyken | e4998ccec8 | |
lyken | 9a82b033b6 | |
lyken | 3b87bd36f3 | |
lyken | b4d5b2a41f | |
lyken | 259958aded | |
lyken | 867f6ccf8e | |
lyken | 23ed5642fb | |
lyken | 2f7e75d7cf | |
lyken | 8863cd64a9 | |
lyken | 9e78139373 | |
lyken | 259481e8d0 | |
lyken | 5faac4b9d4 | |
lyken | c4d54b198b | |
lyken | 9ad7a78dbe | |
lyken | 1721ebac66 | |
lyken | f033639415 | |
lyken | 3116f11814 | |
lyken | 5047379ac0 | |
lyken | 6c10e3d056 | |
lyken | 2dbc1ec659 | |
Sebastien Bourdeauducq | c80378063a | |
abdul124 | 513d30152b | |
abdul124 | 45e9360c4d | |
abdul124 | 2e01b77fc8 | |
abdul124 | cea7cade51 |
|
@ -3,20 +3,34 @@ use std::{
|
||||||
env,
|
env,
|
||||||
fs::File,
|
fs::File,
|
||||||
io::Write,
|
io::Write,
|
||||||
path::Path,
|
path::{Path, PathBuf},
|
||||||
process::{Command, Stdio},
|
process::{Command, Stdio},
|
||||||
};
|
};
|
||||||
|
|
||||||
fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
|
const CMD_IRRT_CLANG: &str = "clang-irrt";
|
||||||
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
const CMD_IRRT_CLANG_TEST: &str = "clang-irrt-test";
|
||||||
|
const CMD_IRRT_LLVM_AS: &str = "llvm-as-irrt";
|
||||||
|
|
||||||
|
fn get_out_dir() -> PathBuf {
|
||||||
|
PathBuf::from(env::var("OUT_DIR").unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_irrt_dir() -> &'static Path {
|
||||||
|
Path::new("irrt")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compile `irrt.cpp` for use in `src/codegen`
|
||||||
|
fn compile_irrt_cpp() {
|
||||||
|
let out_dir = get_out_dir();
|
||||||
|
let irrt_dir = get_irrt_dir();
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||||
*/
|
*/
|
||||||
|
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
||||||
let flags: &[&str] = &[
|
let flags: &[&str] = &[
|
||||||
"--target=wasm32",
|
"--target=wasm32",
|
||||||
irrt_cpp_path.to_str().unwrap(),
|
|
||||||
"-x",
|
"-x",
|
||||||
"c++",
|
"c++",
|
||||||
"-fno-discard-value-names",
|
"-fno-discard-value-names",
|
||||||
|
@ -36,11 +50,14 @@ fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
|
||||||
irrt_dir.to_str().unwrap(),
|
irrt_dir.to_str().unwrap(),
|
||||||
"-o",
|
"-o",
|
||||||
"-",
|
"-",
|
||||||
|
irrt_cpp_path.to_str().unwrap(),
|
||||||
];
|
];
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
|
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
|
||||||
|
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||||
|
|
||||||
let output = Command::new("clang-irrt")
|
// Compile IRRT and capture the LLVM IR output
|
||||||
|
let output = Command::new(CMD_IRRT_CLANG)
|
||||||
.args(flags)
|
.args(flags)
|
||||||
.output()
|
.output()
|
||||||
.map(|o| {
|
.map(|o| {
|
||||||
|
@ -53,11 +70,17 @@ fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
|
||||||
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
||||||
let mut filtered_output = String::with_capacity(output.len());
|
let mut filtered_output = String::with_capacity(output.len());
|
||||||
|
|
||||||
// (?ms:^define.*?\}$) to capture `define` blocks
|
// Filter out irrelevant IR
|
||||||
// (?m:^declare.*?$) to capture `declare` blocks
|
//
|
||||||
// (?m:^%.+?=\s*type\s*\{.+?\}$) to capture `type` declarations
|
// Regex:
|
||||||
let regex_filter =
|
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
|
||||||
Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)").unwrap();
|
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
|
||||||
|
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
|
||||||
|
// - `(?m:^@.+?=.+$)` captures global constants
|
||||||
|
let regex_filter = Regex::new(
|
||||||
|
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
for f in regex_filter.captures_iter(&output) {
|
for f in regex_filter.captures_iter(&output) {
|
||||||
assert_eq!(f.len(), 1);
|
assert_eq!(f.len(), 1);
|
||||||
filtered_output.push_str(&f[0]);
|
filtered_output.push_str(&f[0]);
|
||||||
|
@ -68,15 +91,21 @@ fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.replace_all(&filtered_output, "");
|
.replace_all(&filtered_output, "");
|
||||||
|
|
||||||
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
|
// For debugging
|
||||||
if env::var("DEBUG_DUMP_IRRT").is_ok() {
|
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
|
||||||
|
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
|
||||||
|
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
|
||||||
|
if env::var(DEBUG_DUMP_IRRT).is_ok() {
|
||||||
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
|
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
|
||||||
file.write_all(output.as_bytes()).unwrap();
|
file.write_all(output.as_bytes()).unwrap();
|
||||||
|
|
||||||
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
|
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
|
||||||
file.write_all(filtered_output.as_bytes()).unwrap();
|
file.write_all(filtered_output.as_bytes()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut llvm_as = Command::new("llvm-as-irrt")
|
// Assemble the emitted and filtered IR to .bc
|
||||||
|
// That .bc will be integrated into nac3core's codegen
|
||||||
|
let mut llvm_as = Command::new(CMD_IRRT_LLVM_AS)
|
||||||
.stdin(Stdio::piped())
|
.stdin(Stdio::piped())
|
||||||
.arg("-o")
|
.arg("-o")
|
||||||
.arg(out_dir.join("irrt.bc"))
|
.arg(out_dir.join("irrt.bc"))
|
||||||
|
@ -86,10 +115,13 @@ fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
|
||||||
assert!(llvm_as.wait().unwrap().success());
|
assert!(llvm_as.wait().unwrap().success());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) {
|
/// Compile `irrt_test.cpp` for testing
|
||||||
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
|
fn compile_irrt_test_cpp() {
|
||||||
let exe_path = out_dir.join("irrt_test.out");
|
let out_dir = get_out_dir();
|
||||||
|
let irrt_dir = get_irrt_dir();
|
||||||
|
|
||||||
|
let exe_path = out_dir.join("irrt_test.out"); // Output path of the compiled test executable
|
||||||
|
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
|
||||||
let flags: &[&str] = &[
|
let flags: &[&str] = &[
|
||||||
irrt_test_cpp_path.to_str().unwrap(),
|
irrt_test_cpp_path.to_str().unwrap(),
|
||||||
"-x",
|
"-x",
|
||||||
|
@ -103,11 +135,13 @@ fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) {
|
||||||
"-Wextra",
|
"-Wextra",
|
||||||
"-Werror=return-type",
|
"-Werror=return-type",
|
||||||
"-lm", // for `tgamma()`, `lgamma()`
|
"-lm", // for `tgamma()`, `lgamma()`
|
||||||
|
"-I",
|
||||||
|
irrt_dir.to_str().unwrap(),
|
||||||
"-o",
|
"-o",
|
||||||
exe_path.to_str().unwrap(),
|
exe_path.to_str().unwrap(),
|
||||||
];
|
];
|
||||||
|
|
||||||
Command::new("clang-irrt-test")
|
Command::new(CMD_IRRT_CLANG_TEST)
|
||||||
.args(flags)
|
.args(flags)
|
||||||
.output()
|
.output()
|
||||||
.map(|o| {
|
.map(|o| {
|
||||||
|
@ -115,20 +149,15 @@ fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) {
|
||||||
o
|
o
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
|
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let out_dir = env::var("OUT_DIR").unwrap();
|
compile_irrt_cpp();
|
||||||
let out_dir = Path::new(&out_dir);
|
|
||||||
|
|
||||||
let irrt_dir = Path::new("./irrt");
|
|
||||||
|
|
||||||
compile_irrt(irrt_dir, out_dir);
|
|
||||||
|
|
||||||
// https://github.com/rust-lang/cargo/issues/2549
|
// https://github.com/rust-lang/cargo/issues/2549
|
||||||
// `cargo test -F test` to also build `irrt_test.cpp
|
// `cargo test -F test` to also build `irrt_test.cpp
|
||||||
if cfg!(feature = "test") {
|
if cfg!(feature = "test") {
|
||||||
compile_irrt_test(irrt_dir, out_dir);
|
compile_irrt_test_cpp();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
#include "irrt_everything.hpp"
|
#define IRRT_DEFINE_TYPEDEF_INTS
|
||||||
|
#include <irrt_everything.hpp>
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This file will be read by `clang-irrt` to conveniently produce LLVM IR for `nac3core/codegen`.
|
All IRRT implementations.
|
||||||
*/
|
|
||||||
|
We don't have any pre-compiled objects, so we are writing all implementations in headers and
|
||||||
|
concatenate them with `#include` into one massive source file that contains all the IRRT stuff.
|
||||||
|
*/
|
|
@ -1,437 +0,0 @@
|
||||||
#ifndef IRRT_DONT_TYPEDEF_INTS
|
|
||||||
typedef _BitInt(8) int8_t;
|
|
||||||
typedef unsigned _BitInt(8) uint8_t;
|
|
||||||
typedef _BitInt(32) int32_t;
|
|
||||||
typedef unsigned _BitInt(32) uint32_t;
|
|
||||||
typedef _BitInt(64) int64_t;
|
|
||||||
typedef unsigned _BitInt(64) uint64_t;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// NDArray indices are always `uint32_t`.
|
|
||||||
typedef uint32_t NDIndex;
|
|
||||||
// The type of an index or a value describing the length of a range/slice is
|
|
||||||
// always `int32_t`.
|
|
||||||
typedef int32_t SliceIndex;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static T max(T a, T b) {
|
|
||||||
return a > b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static T min(T a, T b) {
|
|
||||||
return a > b ? b : a;
|
|
||||||
}
|
|
||||||
|
|
||||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
// need to make sure `exp >= 0` before calling this function
|
|
||||||
template <typename T>
|
|
||||||
static T __nac3_int_exp_impl(T base, T exp) {
|
|
||||||
T res = 1;
|
|
||||||
/* repeated squaring method */
|
|
||||||
do {
|
|
||||||
if (exp & 1) {
|
|
||||||
res *= base; /* for n odd */
|
|
||||||
}
|
|
||||||
exp >>= 1;
|
|
||||||
base *= base;
|
|
||||||
} while (exp);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static SizeT __nac3_ndarray_calc_size_impl(
|
|
||||||
const SizeT *list_data,
|
|
||||||
SizeT list_len,
|
|
||||||
SizeT begin_idx,
|
|
||||||
SizeT end_idx
|
|
||||||
) {
|
|
||||||
__builtin_assume(end_idx <= list_len);
|
|
||||||
|
|
||||||
SizeT num_elems = 1;
|
|
||||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
|
||||||
SizeT val = list_data[i];
|
|
||||||
__builtin_assume(val > 0);
|
|
||||||
num_elems *= val;
|
|
||||||
}
|
|
||||||
return num_elems;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static void __nac3_ndarray_calc_nd_indices_impl(
|
|
||||||
SizeT index,
|
|
||||||
const SizeT *dims,
|
|
||||||
SizeT num_dims,
|
|
||||||
NDIndex *idxs
|
|
||||||
) {
|
|
||||||
SizeT stride = 1;
|
|
||||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
|
||||||
SizeT i = num_dims - dim - 1;
|
|
||||||
__builtin_assume(dims[i] > 0);
|
|
||||||
idxs[i] = (index / stride) % dims[i];
|
|
||||||
stride *= dims[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static SizeT __nac3_ndarray_flatten_index_impl(
|
|
||||||
const SizeT *dims,
|
|
||||||
SizeT num_dims,
|
|
||||||
const NDIndex *indices,
|
|
||||||
SizeT num_indices
|
|
||||||
) {
|
|
||||||
SizeT idx = 0;
|
|
||||||
SizeT stride = 1;
|
|
||||||
for (SizeT i = 0; i < num_dims; ++i) {
|
|
||||||
SizeT ri = num_dims - i - 1;
|
|
||||||
if (ri < num_indices) {
|
|
||||||
idx += stride * indices[ri];
|
|
||||||
}
|
|
||||||
|
|
||||||
__builtin_assume(dims[i] > 0);
|
|
||||||
stride *= dims[ri];
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static void __nac3_ndarray_calc_broadcast_impl(
|
|
||||||
const SizeT *lhs_dims,
|
|
||||||
SizeT lhs_ndims,
|
|
||||||
const SizeT *rhs_dims,
|
|
||||||
SizeT rhs_ndims,
|
|
||||||
SizeT *out_dims
|
|
||||||
) {
|
|
||||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
|
||||||
const SizeT *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
|
||||||
const SizeT *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
|
||||||
SizeT *out_dim = &out_dims[max_ndims - i - 1];
|
|
||||||
|
|
||||||
if (lhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (rhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == 1) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (*rhs_dim_sz == 1) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static void __nac3_ndarray_calc_broadcast_idx_impl(
|
|
||||||
const SizeT *src_dims,
|
|
||||||
SizeT src_ndims,
|
|
||||||
const NDIndex *in_idx,
|
|
||||||
NDIndex *out_idx
|
|
||||||
) {
|
|
||||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
|
||||||
SizeT src_i = src_ndims - i - 1;
|
|
||||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
static void __nac3_ndarray_strides_from_shape_impl(
|
|
||||||
SizeT ndims,
|
|
||||||
SizeT *shape,
|
|
||||||
SizeT *dst_strides
|
|
||||||
) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
int dim_i = ndims - i - 1;
|
|
||||||
dst_strides[dim_i] = stride_product;
|
|
||||||
stride_product *= shape[dim_i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
#define DEF_nac3_int_exp_(T) \
|
|
||||||
T __nac3_int_exp_##T(T base, T exp) {\
|
|
||||||
return __nac3_int_exp_impl(base, exp);\
|
|
||||||
}
|
|
||||||
|
|
||||||
DEF_nac3_int_exp_(int32_t)
|
|
||||||
DEF_nac3_int_exp_(int64_t)
|
|
||||||
DEF_nac3_int_exp_(uint32_t)
|
|
||||||
DEF_nac3_int_exp_(uint64_t)
|
|
||||||
|
|
||||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
|
||||||
if (i < 0) {
|
|
||||||
i = len + i;
|
|
||||||
}
|
|
||||||
if (i < 0) {
|
|
||||||
return 0;
|
|
||||||
} else if (i > len) {
|
|
||||||
return len;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
SliceIndex __nac3_range_slice_len(
|
|
||||||
const SliceIndex start,
|
|
||||||
const SliceIndex end,
|
|
||||||
const SliceIndex step
|
|
||||||
) {
|
|
||||||
SliceIndex diff = end - start;
|
|
||||||
if (diff > 0 && step > 0) {
|
|
||||||
return ((diff - 1) / step) + 1;
|
|
||||||
} else if (diff < 0 && step < 0) {
|
|
||||||
return ((diff + 1) / step) + 1;
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle list assignment and dropping part of the list when
|
|
||||||
// both dest_step and src_step are +1.
|
|
||||||
// - All the index must *not* be out-of-bound or negative,
|
|
||||||
// - The end index is *inclusive*,
|
|
||||||
// - The length of src and dest slice size should already
|
|
||||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
|
||||||
SliceIndex __nac3_list_slice_assign_var_size(
|
|
||||||
SliceIndex dest_start,
|
|
||||||
SliceIndex dest_end,
|
|
||||||
SliceIndex dest_step,
|
|
||||||
uint8_t *dest_arr,
|
|
||||||
SliceIndex dest_arr_len,
|
|
||||||
SliceIndex src_start,
|
|
||||||
SliceIndex src_end,
|
|
||||||
SliceIndex src_step,
|
|
||||||
uint8_t *src_arr,
|
|
||||||
SliceIndex src_arr_len,
|
|
||||||
const SliceIndex size
|
|
||||||
) {
|
|
||||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
|
||||||
if (dest_arr_len == 0) return dest_arr_len;
|
|
||||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
|
||||||
if (src_step == dest_step && dest_step == 1) {
|
|
||||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
|
||||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
|
||||||
if (src_len > 0) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_start * size,
|
|
||||||
src_arr + src_start * size,
|
|
||||||
src_len * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (dest_len > 0) {
|
|
||||||
/* dropping */
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + (dest_start + src_len) * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
/* shrink size */
|
|
||||||
return dest_arr_len - (dest_len - src_len);
|
|
||||||
}
|
|
||||||
/* if two range overlaps, need alloca */
|
|
||||||
uint8_t need_alloca =
|
|
||||||
(dest_arr == src_arr)
|
|
||||||
&& !(
|
|
||||||
max(dest_start, dest_end) < min(src_start, src_end)
|
|
||||||
|| max(src_start, src_end) < min(dest_start, dest_end)
|
|
||||||
);
|
|
||||||
if (need_alloca) {
|
|
||||||
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
|
|
||||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
|
||||||
src_arr = tmp;
|
|
||||||
}
|
|
||||||
SliceIndex src_ind = src_start;
|
|
||||||
SliceIndex dest_ind = dest_start;
|
|
||||||
for (;
|
|
||||||
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
|
|
||||||
src_ind += src_step, dest_ind += dest_step
|
|
||||||
) {
|
|
||||||
/* for constant optimization */
|
|
||||||
if (size == 1) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
|
||||||
} else if (size == 4) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
|
||||||
} else if (size == 8) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
|
||||||
} else {
|
|
||||||
/* memcpy for var size, cannot overlap after previous alloca */
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/* only dest_step == 1 can we shrink the dest list. */
|
|
||||||
/* size should be ensured prior to calling this function */
|
|
||||||
if (dest_step == 1 && dest_end >= dest_start) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_ind * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
|
||||||
}
|
|
||||||
return dest_arr_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isinf(double x) {
|
|
||||||
return __builtin_isinf(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isnan(double x) {
|
|
||||||
return __builtin_isnan(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double tgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gamma(double z) {
|
|
||||||
// Handling for denormals
|
|
||||||
// | x | Python gamma(x) | C tgamma(x) |
|
|
||||||
// --- | ----------------- | --------------- | ----------- |
|
|
||||||
// (1) | nan | nan | nan |
|
|
||||||
// (2) | -inf | -inf | inf |
|
|
||||||
// (3) | inf | inf | inf |
|
|
||||||
// (4) | 0.0 | inf | inf |
|
|
||||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
|
||||||
|
|
||||||
// (1)-(3)
|
|
||||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
|
||||||
return z;
|
|
||||||
}
|
|
||||||
|
|
||||||
double v = tgamma(z);
|
|
||||||
|
|
||||||
// (4)-(5)
|
|
||||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
double lgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gammaln(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: gammaln(-inf) -> -inf
|
|
||||||
// - libm : lgamma(-inf) -> inf
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return lgamma(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double j0(double x);
|
|
||||||
|
|
||||||
double __nac3_j0(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: j0(inf) -> nan
|
|
||||||
// - libm : j0(inf) -> 0.0
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return __builtin_nan("");
|
|
||||||
}
|
|
||||||
|
|
||||||
return j0(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_calc_size(
|
|
||||||
const uint32_t *list_data,
|
|
||||||
uint32_t list_len,
|
|
||||||
uint32_t begin_idx,
|
|
||||||
uint32_t end_idx
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_calc_size64(
|
|
||||||
const uint64_t *list_data,
|
|
||||||
uint64_t list_len,
|
|
||||||
uint64_t begin_idx,
|
|
||||||
uint64_t end_idx
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_nd_indices(
|
|
||||||
uint32_t index,
|
|
||||||
const uint32_t* dims,
|
|
||||||
uint32_t num_dims,
|
|
||||||
NDIndex* idxs
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_nd_indices64(
|
|
||||||
uint64_t index,
|
|
||||||
const uint64_t* dims,
|
|
||||||
uint64_t num_dims,
|
|
||||||
NDIndex* idxs
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_flatten_index(
|
|
||||||
const uint32_t* dims,
|
|
||||||
uint32_t num_dims,
|
|
||||||
const NDIndex* indices,
|
|
||||||
uint32_t num_indices
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_flatten_index64(
|
|
||||||
const uint64_t* dims,
|
|
||||||
uint64_t num_dims,
|
|
||||||
const NDIndex* indices,
|
|
||||||
uint64_t num_indices
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast(
|
|
||||||
const uint32_t *lhs_dims,
|
|
||||||
uint32_t lhs_ndims,
|
|
||||||
const uint32_t *rhs_dims,
|
|
||||||
uint32_t rhs_ndims,
|
|
||||||
uint32_t *out_dims
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast64(
|
|
||||||
const uint64_t *lhs_dims,
|
|
||||||
uint64_t lhs_ndims,
|
|
||||||
const uint64_t *rhs_dims,
|
|
||||||
uint64_t rhs_ndims,
|
|
||||||
uint64_t *out_dims
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx(
|
|
||||||
const uint32_t *src_dims,
|
|
||||||
uint32_t src_ndims,
|
|
||||||
const NDIndex *in_idx,
|
|
||||||
NDIndex *out_idx
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx64(
|
|
||||||
const uint64_t *src_dims,
|
|
||||||
uint64_t src_ndims,
|
|
||||||
const NDIndex *in_idx,
|
|
||||||
NDIndex *out_idx
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_strides_from_shape(uint32_t ndims, uint32_t* shape, uint32_t* dst_strides) {
|
|
||||||
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_strides_from_shape64(uint64_t ndims, uint64_t* shape, uint64_t* dst_strides) {
|
|
||||||
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,402 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
||||||
|
// NDArray indices are always `uint32_t`.
|
||||||
|
using NDIndex = uint32_t;
|
||||||
|
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||||
|
using SliceIndex = int32_t;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
|
// need to make sure `exp >= 0` before calling this function
|
||||||
|
template <typename T>
|
||||||
|
T __nac3_int_exp_impl(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
/* repeated squaring method */
|
||||||
|
do {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base; /* for n odd */
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
} while (exp);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_calc_size_impl(
|
||||||
|
const SizeT* list_data,
|
||||||
|
SizeT list_len,
|
||||||
|
SizeT begin_idx,
|
||||||
|
SizeT end_idx
|
||||||
|
) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
|
SizeT num_elems = 1;
|
||||||
|
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||||
|
SizeT val = list_data[i];
|
||||||
|
__builtin_assume(val > 0);
|
||||||
|
num_elems *= val;
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_nd_indices_impl(
|
||||||
|
SizeT index,
|
||||||
|
const SizeT* dims,
|
||||||
|
SizeT num_dims,
|
||||||
|
NDIndex* idxs
|
||||||
|
) {
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||||
|
SizeT i = num_dims - dim - 1;
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_flatten_index_impl(
|
||||||
|
const SizeT* dims,
|
||||||
|
SizeT num_dims,
|
||||||
|
const NDIndex* indices,
|
||||||
|
SizeT num_indices
|
||||||
|
) {
|
||||||
|
SizeT idx = 0;
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT i = 0; i < num_dims; ++i) {
|
||||||
|
SizeT ri = num_dims - i - 1;
|
||||||
|
if (ri < num_indices) {
|
||||||
|
idx += stride * indices[ri];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
stride *= dims[ri];
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_impl(
|
||||||
|
const SizeT* lhs_dims,
|
||||||
|
SizeT lhs_ndims,
|
||||||
|
const SizeT* rhs_dims,
|
||||||
|
SizeT rhs_ndims,
|
||||||
|
SizeT* out_dims
|
||||||
|
) {
|
||||||
|
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||||
|
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||||
|
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||||
|
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||||
|
|
||||||
|
if (lhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (rhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == 1) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (*rhs_dim_sz == 1) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx_impl(
|
||||||
|
const SizeT* src_dims,
|
||||||
|
SizeT src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||||
|
SizeT src_i = src_ndims - i - 1;
|
||||||
|
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
#define DEF_nac3_int_exp_(T) \
|
||||||
|
T __nac3_int_exp_##T(T base, T exp) {\
|
||||||
|
return __nac3_int_exp_impl(base, exp);\
|
||||||
|
}
|
||||||
|
|
||||||
|
DEF_nac3_int_exp_(int32_t)
|
||||||
|
DEF_nac3_int_exp_(int64_t)
|
||||||
|
DEF_nac3_int_exp_(uint32_t)
|
||||||
|
DEF_nac3_int_exp_(uint64_t)
|
||||||
|
|
||||||
|
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||||
|
if (i < 0) {
|
||||||
|
i = len + i;
|
||||||
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
return 0;
|
||||||
|
} else if (i > len) {
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_range_slice_len(
|
||||||
|
const SliceIndex start,
|
||||||
|
const SliceIndex end,
|
||||||
|
const SliceIndex step
|
||||||
|
) {
|
||||||
|
SliceIndex diff = end - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle list assignment and dropping part of the list when
|
||||||
|
// both dest_step and src_step are +1.
|
||||||
|
// - All the index must *not* be out-of-bound or negative,
|
||||||
|
// - The end index is *inclusive*,
|
||||||
|
// - The length of src and dest slice size should already
|
||||||
|
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
||||||
|
SliceIndex __nac3_list_slice_assign_var_size(
|
||||||
|
SliceIndex dest_start,
|
||||||
|
SliceIndex dest_end,
|
||||||
|
SliceIndex dest_step,
|
||||||
|
uint8_t* dest_arr,
|
||||||
|
SliceIndex dest_arr_len,
|
||||||
|
SliceIndex src_start,
|
||||||
|
SliceIndex src_end,
|
||||||
|
SliceIndex src_step,
|
||||||
|
uint8_t* src_arr,
|
||||||
|
SliceIndex src_arr_len,
|
||||||
|
const SliceIndex size
|
||||||
|
) {
|
||||||
|
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||||
|
if (dest_arr_len == 0) return dest_arr_len;
|
||||||
|
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
||||||
|
if (src_step == dest_step && dest_step == 1) {
|
||||||
|
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||||
|
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||||
|
if (src_len > 0) {
|
||||||
|
__builtin_memmove(
|
||||||
|
dest_arr + dest_start * size,
|
||||||
|
src_arr + src_start * size,
|
||||||
|
src_len * size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (dest_len > 0) {
|
||||||
|
/* dropping */
|
||||||
|
__builtin_memmove(
|
||||||
|
dest_arr + (dest_start + src_len) * size,
|
||||||
|
dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
/* shrink size */
|
||||||
|
return dest_arr_len - (dest_len - src_len);
|
||||||
|
}
|
||||||
|
/* if two range overlaps, need alloca */
|
||||||
|
uint8_t need_alloca =
|
||||||
|
(dest_arr == src_arr)
|
||||||
|
&& !(
|
||||||
|
max(dest_start, dest_end) < min(src_start, src_end)
|
||||||
|
|| max(src_start, src_end) < min(dest_start, dest_end)
|
||||||
|
);
|
||||||
|
if (need_alloca) {
|
||||||
|
uint8_t* tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
|
||||||
|
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||||
|
src_arr = tmp;
|
||||||
|
}
|
||||||
|
SliceIndex src_ind = src_start;
|
||||||
|
SliceIndex dest_ind = dest_start;
|
||||||
|
for (;
|
||||||
|
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
|
||||||
|
src_ind += src_step, dest_ind += dest_step
|
||||||
|
) {
|
||||||
|
/* for constant optimization */
|
||||||
|
if (size == 1) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||||
|
} else if (size == 4) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||||
|
} else if (size == 8) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||||
|
} else {
|
||||||
|
/* memcpy for var size, cannot overlap after previous alloca */
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* only dest_step == 1 can we shrink the dest list. */
|
||||||
|
/* size should be ensured prior to calling this function */
|
||||||
|
if (dest_step == 1 && dest_end >= dest_start) {
|
||||||
|
__builtin_memmove(
|
||||||
|
dest_arr + dest_ind * size,
|
||||||
|
dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size
|
||||||
|
);
|
||||||
|
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||||
|
}
|
||||||
|
return dest_arr_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_isinf(double x) {
|
||||||
|
return __builtin_isinf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_isnan(double x) {
|
||||||
|
return __builtin_isnan(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double tgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gamma(double z) {
|
||||||
|
// Handling for denormals
|
||||||
|
// | x | Python gamma(x) | C tgamma(x) |
|
||||||
|
// --- | ----------------- | --------------- | ----------- |
|
||||||
|
// (1) | nan | nan | nan |
|
||||||
|
// (2) | -inf | -inf | inf |
|
||||||
|
// (3) | inf | inf | inf |
|
||||||
|
// (4) | 0.0 | inf | inf |
|
||||||
|
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||||
|
|
||||||
|
// (1)-(3)
|
||||||
|
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
double v = tgamma(z);
|
||||||
|
|
||||||
|
// (4)-(5)
|
||||||
|
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
double lgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gammaln(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: gammaln(-inf) -> -inf
|
||||||
|
// - libm : lgamma(-inf) -> inf
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lgamma(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double j0(double x);
|
||||||
|
|
||||||
|
double __nac3_j0(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: j0(inf) -> nan
|
||||||
|
// - libm : j0(inf) -> 0.0
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return __builtin_nan("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return j0(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_calc_size(
|
||||||
|
const uint32_t* list_data,
|
||||||
|
uint32_t list_len,
|
||||||
|
uint32_t begin_idx,
|
||||||
|
uint32_t end_idx
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_calc_size64(
|
||||||
|
const uint64_t* list_data,
|
||||||
|
uint64_t list_len,
|
||||||
|
uint64_t begin_idx,
|
||||||
|
uint64_t end_idx
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices(
|
||||||
|
uint32_t index,
|
||||||
|
const uint32_t* dims,
|
||||||
|
uint32_t num_dims,
|
||||||
|
NDIndex* idxs
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices64(
|
||||||
|
uint64_t index,
|
||||||
|
const uint64_t* dims,
|
||||||
|
uint64_t num_dims,
|
||||||
|
NDIndex* idxs
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_flatten_index(
|
||||||
|
const uint32_t* dims,
|
||||||
|
uint32_t num_dims,
|
||||||
|
const NDIndex* indices,
|
||||||
|
uint32_t num_indices
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_flatten_index64(
|
||||||
|
const uint64_t* dims,
|
||||||
|
uint64_t num_dims,
|
||||||
|
const NDIndex* indices,
|
||||||
|
uint64_t num_indices
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast(
|
||||||
|
const uint32_t* lhs_dims,
|
||||||
|
uint32_t lhs_ndims,
|
||||||
|
const uint32_t* rhs_dims,
|
||||||
|
uint32_t rhs_ndims,
|
||||||
|
uint32_t* out_dims
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast64(
|
||||||
|
const uint64_t* lhs_dims,
|
||||||
|
uint64_t lhs_ndims,
|
||||||
|
const uint64_t* rhs_dims,
|
||||||
|
uint64_t rhs_ndims,
|
||||||
|
uint64_t* out_dims
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx(
|
||||||
|
const uint32_t* src_dims,
|
||||||
|
uint32_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx64(
|
||||||
|
const uint64_t* src_dims,
|
||||||
|
uint64_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
} // extern "C"
|
|
@ -0,0 +1,85 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// nac3core's "str" struct type definition
|
||||||
|
template <typename SizeT>
|
||||||
|
struct Str {
|
||||||
|
const char* content;
|
||||||
|
SizeT length;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A limited set of errors IRRT could use.
|
||||||
|
typedef uint32_t ErrorId;
|
||||||
|
struct ErrorIds {
|
||||||
|
ErrorId index_error;
|
||||||
|
ErrorId value_error;
|
||||||
|
ErrorId assertion_error;
|
||||||
|
ErrorId runtime_error;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ErrorContext {
|
||||||
|
// Context
|
||||||
|
ErrorIds* error_ids;
|
||||||
|
|
||||||
|
// Error thrown by IRRT
|
||||||
|
ErrorId error_id;
|
||||||
|
const char* message_template; // MUST BE `&'static`
|
||||||
|
uint64_t param1;
|
||||||
|
uint64_t param2;
|
||||||
|
uint64_t param3;
|
||||||
|
|
||||||
|
void initialize(ErrorIds* error_ids) {
|
||||||
|
this->error_ids = error_ids;
|
||||||
|
clear_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear_error() {
|
||||||
|
// Point the message_template to an empty str. Don't set it to nullptr as a sentinel
|
||||||
|
this->message_template = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_error(ErrorId error_id, const char* message, uint64_t param1 = 0, uint64_t param2 = 0, uint64_t param3 = 0) {
|
||||||
|
this->error_id = error_id;
|
||||||
|
this->message_template = message;
|
||||||
|
this->param1 = param1;
|
||||||
|
this->param2 = param2;
|
||||||
|
this->param3 = param3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_error() {
|
||||||
|
return !cstr_utils::is_empty(message_template);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_error_str(Str<SizeT> *dst_str) {
|
||||||
|
dst_str->content = message_template;
|
||||||
|
dst_str->length = (SizeT) cstr_utils::length(message_template);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void __nac3_error_context_initialize(ErrorContext* errctx, ErrorIds* error_ids) {
|
||||||
|
errctx->initialize(error_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t __nac3_error_context_has_no_error(ErrorContext* errctx) {
|
||||||
|
return !errctx->has_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_error_context_get_error_str(ErrorContext* errctx, Str<int32_t> *dst_str) {
|
||||||
|
errctx->set_error_str<int32_t>(dst_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str<int64_t> *dst_str) {
|
||||||
|
errctx->set_error_str<int64_t>(dst_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_error_dummy_raise(ErrorContext* errctx) {
|
||||||
|
errctx->set_error(errctx->error_ids->runtime_error, "THROWN FROM __nac3_error_dummy_raise!!!!!!");
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,14 +1,12 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// This is made toggleable since `irrt_test.cpp` itself would include
|
// This is made toggleable since `irrt_test.cpp` itself would include
|
||||||
// headers that define the `int_t` family.
|
// headers that define these typedefs
|
||||||
#ifndef IRRT_DONT_TYPEDEF_INTS
|
#ifdef IRRT_DEFINE_TYPEDEF_INTS
|
||||||
typedef _BitInt(8) int8_t;
|
typedef _BitInt(8) int8_t;
|
||||||
typedef unsigned _BitInt(8) uint8_t;
|
typedef unsigned _BitInt(8) uint8_t;
|
||||||
typedef _BitInt(32) int32_t;
|
typedef _BitInt(32) int32_t;
|
||||||
typedef unsigned _BitInt(32) uint32_t;
|
typedef unsigned _BitInt(32) uint32_t;
|
||||||
typedef _BitInt(64) int64_t;
|
typedef _BitInt(64) int64_t;
|
||||||
typedef unsigned _BitInt(64) uint64_t;
|
typedef unsigned _BitInt(64) uint64_t;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
typedef int32_t SliceIndex;
|
|
|
@ -0,0 +1,155 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/ndarray/ndarray_util.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||||||
|
//
|
||||||
|
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||||||
|
//
|
||||||
|
// Some resources you might find helpful:
|
||||||
|
// - The official numpy implementations:
|
||||||
|
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||||
|
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||||||
|
// - https://ajcr.net/stride-guide-part-1/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-2/.
|
||||||
|
// - https://ajcr.net/stride-guide-part-3/.
|
||||||
|
template <typename SizeT>
|
||||||
|
struct NDArray {
|
||||||
|
// The underlying data this `ndarray` is pointing to.
|
||||||
|
//
|
||||||
|
// NOTE: Formally this should be of type `void *`, but clang
|
||||||
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||||
|
// so we will put `uint8_t *` here for clarity.
|
||||||
|
//
|
||||||
|
// This pointer should point to the first element of the ndarray directly
|
||||||
|
uint8_t *data;
|
||||||
|
|
||||||
|
// The number of bytes of a single element in `data`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT itemsize;
|
||||||
|
|
||||||
|
// The number of dimensions of this shape.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
SizeT ndims;
|
||||||
|
|
||||||
|
// Array shape, with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `unsigned`.
|
||||||
|
//
|
||||||
|
// NOTE: `shape` can contain 0.
|
||||||
|
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||||||
|
SizeT *shape;
|
||||||
|
|
||||||
|
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||||||
|
//
|
||||||
|
// The `SizeT` is treated as `signed`.
|
||||||
|
//
|
||||||
|
// NOTE: `strides` can have negative numbers.
|
||||||
|
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||||||
|
SizeT *strides;
|
||||||
|
|
||||||
|
// Calculate the size/# of elements of an `ndarray`.
|
||||||
|
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||||||
|
SizeT size() {
|
||||||
|
return ndarray_util::calc_size_from_shape(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||||||
|
// This function corresponds to `ndarray.nbytes`
|
||||||
|
SizeT nbytes() {
|
||||||
|
return this->size() * itemsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||||||
|
void set_strides_by_shape() {
|
||||||
|
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t* get_pelement_by_indices(const SizeT *indices) {
|
||||||
|
uint8_t* element = data;
|
||||||
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
||||||
|
element += indices[dim_i] * strides[dim_i];
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t* get_nth_pelement(SizeT nth) {
|
||||||
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
||||||
|
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
||||||
|
return get_pelement_by_indices(indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the pointer to the nth element of the ndarray as if it were flattened.
|
||||||
|
uint8_t* checked_get_nth_pelement(ErrorContext* errctx, SizeT nth) {
|
||||||
|
SizeT arr_size = this->size();
|
||||||
|
if (!(0 <= nth && nth < arr_size)) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->index_error,
|
||||||
|
"index {0} is out of bounds, valid range is {1} <= index < {2}",
|
||||||
|
nth, 0, arr_size
|
||||||
|
);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return get_nth_pelement(nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) {
|
||||||
|
__builtin_memcpy(pelement, pvalue, itemsize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill the ndarray with a value
|
||||||
|
void fill_generic(const uint8_t* pvalue) {
|
||||||
|
const SizeT size = this->size();
|
||||||
|
for (SizeT i = 0; i < size; i++) {
|
||||||
|
uint8_t* pelement = get_nth_pelement(i); // No need for checked_get_nth_pelement
|
||||||
|
set_pelement_value(pelement, pvalue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray->nbytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray->nbytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||||
|
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
||||||
|
ndarray_util::assert_shape_no_negative(errctx, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||||
|
ndarray->set_strides_by_shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||||
|
ndarray->set_strides_by_shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
ndarray->fill_generic(pvalue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
ndarray->fill_generic(pvalue);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,107 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray_util {
|
||||||
|
|
||||||
|
// Throw an error if there is an axis with negative dimension
|
||||||
|
template <typename SizeT>
|
||||||
|
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) {
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
if (shape[axis] < 0) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"negative dimensions are not allowed; axis {0} has dimension {1}",
|
||||||
|
axis, shape[axis]
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the size/# of elements of an ndarray given its shape
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||||
|
SizeT size = 1;
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the strides of an ndarray given an ndarray `shape`
|
||||||
|
// and assuming that the ndarray is *fully C-contagious*.
|
||||||
|
//
|
||||||
|
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
int axis = ndims - i - 1;
|
||||||
|
dst_strides[axis] = stride_product * itemsize;
|
||||||
|
stride_product *= shape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (int32_t i = 0; i < ndims; i++) {
|
||||||
|
int32_t axis = ndims - i - 1;
|
||||||
|
int32_t dim = shape[axis];
|
||||||
|
|
||||||
|
indices[axis] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
bool can_broadcast_shape_to(
|
||||||
|
const SizeT target_ndims,
|
||||||
|
const SizeT *target_shape,
|
||||||
|
const SizeT src_ndims,
|
||||||
|
const SizeT *src_shape
|
||||||
|
) {
|
||||||
|
/*
|
||||||
|
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
|
||||||
|
This function handles this example:
|
||||||
|
```
|
||||||
|
Image (3d array): 256 x 256 x 3
|
||||||
|
Scale (1d array): 3
|
||||||
|
Result (3d array): 256 x 256 x 3
|
||||||
|
```
|
||||||
|
|
||||||
|
Other interesting examples to consider:
|
||||||
|
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||||
|
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||||
|
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||||
|
|
||||||
|
In cases when the shapes contain zero(es):
|
||||||
|
- `can_broadcast_shape_to([0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0], [2]) == false`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This is essentially doing the following in Python:
|
||||||
|
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||||
|
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||||
|
SizeT target_axis = target_ndims - i - 1;
|
||||||
|
SizeT src_axis = src_ndims - i - 1;
|
||||||
|
|
||||||
|
bool target_dim_exists = target_axis >= 0;
|
||||||
|
bool src_dim_exists = src_axis >= 0;
|
||||||
|
|
||||||
|
SizeT target_dim = target_dim_exists ? target_shape[target_axis] : 1;
|
||||||
|
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
||||||
|
|
||||||
|
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||||
|
if (!ok) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace string {
|
||||||
|
bool is_empty(const char* str) {
|
||||||
|
return str[0] == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t compare(const char* a, const char* b) {
|
||||||
|
uint32_t i = 0;
|
||||||
|
while (true) {
|
||||||
|
if (a[i] < b[i]) {
|
||||||
|
return -1;
|
||||||
|
} else if (a[i] > b[i]) {
|
||||||
|
return 1;
|
||||||
|
} else { // a[i] == b[i]
|
||||||
|
if (a[i] == '\0') {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t equal(const char* a, const char* b) {
|
||||||
|
return compare(a, b) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t length(const char* str) {
|
||||||
|
uint32_t length = 0;
|
||||||
|
while (*str != '\0') {
|
||||||
|
length++;
|
||||||
|
str++;
|
||||||
|
}
|
||||||
|
return length;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
|
||||||
|
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||||
|
bool is_last = i + 1 == dst_max_size;
|
||||||
|
if (is_last && src[i] != '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src[i] == '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
const T& max(const T& a, const T& b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
const T& min(const T& a, const T& b) {
|
||||||
|
return a > b ? b : a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool arrays_match(int len, T* as, T* bs) {
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
if (as[i] != bs[i]) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
uint32_t int_log_floor(T value, T base) {
|
||||||
|
uint32_t result = 0;
|
||||||
|
while (value >= base) {
|
||||||
|
result++;
|
||||||
|
value /= base;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace cstr_utils {
|
||||||
|
bool is_empty(const char* str) {
|
||||||
|
return str[0] == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t compare(const char* a, const char* b) {
|
||||||
|
uint32_t i = 0;
|
||||||
|
while (true) {
|
||||||
|
if (a[i] < b[i]) {
|
||||||
|
return -1;
|
||||||
|
} else if (a[i] > b[i]) {
|
||||||
|
return 1;
|
||||||
|
} else { // a[i] == b[i]
|
||||||
|
if (a[i] == '\0') {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t equal(const char* a, const char* b) {
|
||||||
|
return compare(a, b) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t length(const char* str) {
|
||||||
|
uint32_t length = 0;
|
||||||
|
while (*str != '\0') {
|
||||||
|
length++;
|
||||||
|
str++;
|
||||||
|
}
|
||||||
|
return length;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
|
||||||
|
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||||
|
bool is_last = i + 1 == dst_max_size;
|
||||||
|
if (is_last && src[i] != '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src[i] == '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,216 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt_utils.hpp"
|
|
||||||
#include "irrt_typedefs.hpp"
|
|
||||||
|
|
||||||
/*
|
|
||||||
This header contains IRRT implementations
|
|
||||||
that do not deserved to be categorized (e.g., into numpy, etc.)
|
|
||||||
|
|
||||||
Check out other *.hpp files before including them here!!
|
|
||||||
*/
|
|
||||||
|
|
||||||
// The type of an index or a value describing the length of a range/slice is
|
|
||||||
// always `int32_t`.
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
// need to make sure `exp >= 0` before calling this function
|
|
||||||
template <typename T>
|
|
||||||
T __nac3_int_exp_impl(T base, T exp) {
|
|
||||||
T res = 1;
|
|
||||||
/* repeated squaring method */
|
|
||||||
do {
|
|
||||||
if (exp & 1) {
|
|
||||||
res *= base; /* for n odd */
|
|
||||||
}
|
|
||||||
exp >>= 1;
|
|
||||||
base *= base;
|
|
||||||
} while (exp);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
#define DEF_nac3_int_exp_(T) \
|
|
||||||
T __nac3_int_exp_##T(T base, T exp) {\
|
|
||||||
return __nac3_int_exp_impl(base, exp);\
|
|
||||||
}
|
|
||||||
|
|
||||||
DEF_nac3_int_exp_(int32_t)
|
|
||||||
DEF_nac3_int_exp_(int64_t)
|
|
||||||
DEF_nac3_int_exp_(uint32_t)
|
|
||||||
DEF_nac3_int_exp_(uint64_t)
|
|
||||||
|
|
||||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
|
||||||
if (i < 0) {
|
|
||||||
i = len + i;
|
|
||||||
}
|
|
||||||
if (i < 0) {
|
|
||||||
return 0;
|
|
||||||
} else if (i > len) {
|
|
||||||
return len;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
SliceIndex __nac3_range_slice_len(
|
|
||||||
const SliceIndex start,
|
|
||||||
const SliceIndex end,
|
|
||||||
const SliceIndex step
|
|
||||||
) {
|
|
||||||
SliceIndex diff = end - start;
|
|
||||||
if (diff > 0 && step > 0) {
|
|
||||||
return ((diff - 1) / step) + 1;
|
|
||||||
} else if (diff < 0 && step < 0) {
|
|
||||||
return ((diff + 1) / step) + 1;
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle list assignment and dropping part of the list when
|
|
||||||
// both dest_step and src_step are +1.
|
|
||||||
// - All the index must *not* be out-of-bound or negative,
|
|
||||||
// - The end index is *inclusive*,
|
|
||||||
// - The length of src and dest slice size should already
|
|
||||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
|
||||||
SliceIndex __nac3_list_slice_assign_var_size(
|
|
||||||
SliceIndex dest_start,
|
|
||||||
SliceIndex dest_end,
|
|
||||||
SliceIndex dest_step,
|
|
||||||
uint8_t *dest_arr,
|
|
||||||
SliceIndex dest_arr_len,
|
|
||||||
SliceIndex src_start,
|
|
||||||
SliceIndex src_end,
|
|
||||||
SliceIndex src_step,
|
|
||||||
uint8_t *src_arr,
|
|
||||||
SliceIndex src_arr_len,
|
|
||||||
const SliceIndex size
|
|
||||||
) {
|
|
||||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
|
||||||
if (dest_arr_len == 0) return dest_arr_len;
|
|
||||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
|
||||||
if (src_step == dest_step && dest_step == 1) {
|
|
||||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
|
||||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
|
||||||
if (src_len > 0) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_start * size,
|
|
||||||
src_arr + src_start * size,
|
|
||||||
src_len * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (dest_len > 0) {
|
|
||||||
/* dropping */
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + (dest_start + src_len) * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
/* shrink size */
|
|
||||||
return dest_arr_len - (dest_len - src_len);
|
|
||||||
}
|
|
||||||
/* if two range overlaps, need alloca */
|
|
||||||
uint8_t need_alloca =
|
|
||||||
(dest_arr == src_arr)
|
|
||||||
&& !(
|
|
||||||
max(dest_start, dest_end) < min(src_start, src_end)
|
|
||||||
|| max(src_start, src_end) < min(dest_start, dest_end)
|
|
||||||
);
|
|
||||||
if (need_alloca) {
|
|
||||||
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
|
|
||||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
|
||||||
src_arr = tmp;
|
|
||||||
}
|
|
||||||
SliceIndex src_ind = src_start;
|
|
||||||
SliceIndex dest_ind = dest_start;
|
|
||||||
for (;
|
|
||||||
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
|
|
||||||
src_ind += src_step, dest_ind += dest_step
|
|
||||||
) {
|
|
||||||
/* for constant optimization */
|
|
||||||
if (size == 1) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
|
||||||
} else if (size == 4) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
|
||||||
} else if (size == 8) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
|
||||||
} else {
|
|
||||||
/* memcpy for var size, cannot overlap after previous alloca */
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/* only dest_step == 1 can we shrink the dest list. */
|
|
||||||
/* size should be ensured prior to calling this function */
|
|
||||||
if (dest_step == 1 && dest_end >= dest_start) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_ind * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
|
||||||
}
|
|
||||||
return dest_arr_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isinf(double x) {
|
|
||||||
return __builtin_isinf(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isnan(double x) {
|
|
||||||
return __builtin_isnan(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double tgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gamma(double z) {
|
|
||||||
// Handling for denormals
|
|
||||||
// | x | Python gamma(x) | C tgamma(x) |
|
|
||||||
// --- | ----------------- | --------------- | ----------- |
|
|
||||||
// (1) | nan | nan | nan |
|
|
||||||
// (2) | -inf | -inf | inf |
|
|
||||||
// (3) | inf | inf | inf |
|
|
||||||
// (4) | 0.0 | inf | inf |
|
|
||||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
|
||||||
|
|
||||||
// (1)-(3)
|
|
||||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
|
||||||
return z;
|
|
||||||
}
|
|
||||||
|
|
||||||
double v = tgamma(z);
|
|
||||||
|
|
||||||
// (4)-(5)
|
|
||||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
double lgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gammaln(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: gammaln(-inf) -> -inf
|
|
||||||
// - libm : lgamma(-inf) -> inf
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return lgamma(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double j0(double x);
|
|
||||||
|
|
||||||
double __nac3_j0(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: j0(inf) -> nan
|
|
||||||
// - libm : j0(inf) -> 0.0
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return __builtin_nan("");
|
|
||||||
}
|
|
||||||
|
|
||||||
return j0(x);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,14 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "irrt_utils.hpp"
|
#include <irrt/core.hpp>
|
||||||
#include "irrt_typedefs.hpp"
|
#include <irrt/error_context.hpp>
|
||||||
#include "irrt_basic.hpp"
|
#include <irrt/int_defs.hpp>
|
||||||
#include "irrt_slice.hpp"
|
#include <irrt/utils.hpp>
|
||||||
#include "irrt_numpy_ndarray.hpp"
|
#include <irrt/ndarray/ndarray.hpp>
|
||||||
|
|
||||||
/*
|
|
||||||
All IRRT implementations.
|
|
||||||
|
|
||||||
We don't have any pre-compiled objects, so we are writing all implementations in headers and
|
|
||||||
concatenate them with `#include` into one massive source file that contains all the IRRT stuff.
|
|
||||||
*/
|
|
|
@ -1,466 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt_utils.hpp"
|
|
||||||
#include "irrt_typedefs.hpp"
|
|
||||||
#include "irrt_slice.hpp"
|
|
||||||
|
|
||||||
/*
|
|
||||||
NDArray-related implementations.
|
|
||||||
`*/
|
|
||||||
|
|
||||||
// NDArray indices are always `uint32_t`.
|
|
||||||
using NDIndex = uint32_t;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray_util {
|
|
||||||
template <typename SizeT>
|
|
||||||
static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
|
||||||
for (int32_t i = 0; i < ndims; i++) {
|
|
||||||
int32_t dim_i = ndims - i - 1;
|
|
||||||
int32_t dim = shape[dim_i];
|
|
||||||
|
|
||||||
indices[dim_i] = nth % dim;
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the strides of an ndarray given an ndarray `shape`
|
|
||||||
// and assuming that the ndarray is *fully C-contagious*.
|
|
||||||
//
|
|
||||||
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
|
||||||
template <typename SizeT>
|
|
||||||
static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
int dim_i = ndims - i - 1;
|
|
||||||
dst_strides[dim_i] = stride_product * itemsize;
|
|
||||||
stride_product *= shape[dim_i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the size/# of elements of an ndarray given its shape
|
|
||||||
template <typename SizeT>
|
|
||||||
static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
|
||||||
SizeT size = 1;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
static bool can_broadcast_shape_to(
|
|
||||||
const SizeT target_ndims,
|
|
||||||
const SizeT *target_shape,
|
|
||||||
const SizeT src_ndims,
|
|
||||||
const SizeT *src_shape
|
|
||||||
) {
|
|
||||||
/*
|
|
||||||
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
|
||||||
|
|
||||||
This function handles this example:
|
|
||||||
```
|
|
||||||
Image (3d array): 256 x 256 x 3
|
|
||||||
Scale (1d array): 3
|
|
||||||
Result (3d array): 256 x 256 x 3
|
|
||||||
```
|
|
||||||
|
|
||||||
Other interesting examples to consider:
|
|
||||||
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
|
||||||
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
|
||||||
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
|
||||||
|
|
||||||
In cases when the shapes contain zero(es):
|
|
||||||
- `can_broadcast_shape_to([0], [1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0], [2]) == false`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
|
||||||
*/
|
|
||||||
|
|
||||||
// This is essentially doing the following in Python:
|
|
||||||
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
|
||||||
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
|
||||||
SizeT target_dim_i = target_ndims - i - 1;
|
|
||||||
SizeT src_dim_i = src_ndims - i - 1;
|
|
||||||
|
|
||||||
bool target_dim_exists = target_dim_i >= 0;
|
|
||||||
bool src_dim_exists = src_dim_i >= 0;
|
|
||||||
|
|
||||||
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
|
|
||||||
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
|
|
||||||
|
|
||||||
bool ok = src_dim == 1 || target_dim == src_dim;
|
|
||||||
if (!ok) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef uint8_t NDSliceType;
|
|
||||||
extern "C" {
|
|
||||||
const NDSliceType INPUT_SLICE_TYPE_INDEX = 0;
|
|
||||||
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct NDSlice {
|
|
||||||
// A poor-man's `std::variant<int, UserRange>`
|
|
||||||
NDSliceType type;
|
|
||||||
|
|
||||||
/*
|
|
||||||
if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT`
|
|
||||||
if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange`
|
|
||||||
*/
|
|
||||||
uint8_t *slice;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace ndarray_util {
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) {
|
|
||||||
irrt_assert(num_slices <= ndims);
|
|
||||||
|
|
||||||
SizeT final_ndims = ndims;
|
|
||||||
for (SizeT i = 0; i < num_slices; i++) {
|
|
||||||
if (slices[i].type == INPUT_SLICE_TYPE_INDEX) {
|
|
||||||
final_ndims--; // An integer slice demotes the rank by 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return final_ndims;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
struct NDArrayIndicesIter {
|
|
||||||
SizeT ndims;
|
|
||||||
const SizeT *shape;
|
|
||||||
SizeT *indices;
|
|
||||||
|
|
||||||
void set_indices_zero() {
|
|
||||||
__builtin_memset(indices, 0, sizeof(SizeT) * ndims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void next() {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT dim_i = ndims - i - 1;
|
|
||||||
|
|
||||||
indices[dim_i]++;
|
|
||||||
if (indices[dim_i] < shape[dim_i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
indices[dim_i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
|
||||||
//
|
|
||||||
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
|
||||||
//
|
|
||||||
// Some resources you might find helpful:
|
|
||||||
// - The official numpy implementations:
|
|
||||||
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
|
||||||
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
|
||||||
// - https://ajcr.net/stride-guide-part-1/.
|
|
||||||
// - https://ajcr.net/stride-guide-part-2/.
|
|
||||||
// - https://ajcr.net/stride-guide-part-3/.
|
|
||||||
template <typename SizeT>
|
|
||||||
struct NDArray {
|
|
||||||
// The underlying data this `ndarray` is pointing to.
|
|
||||||
//
|
|
||||||
// NOTE: Formally this should be of type `void *`, but clang
|
|
||||||
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
|
||||||
// so we will put `uint8_t *` here for clarity.
|
|
||||||
uint8_t *data;
|
|
||||||
|
|
||||||
// The number of bytes of a single element in `data`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
SizeT itemsize;
|
|
||||||
|
|
||||||
// The number of dimensions of this shape.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
SizeT ndims;
|
|
||||||
|
|
||||||
// Array shape, with length equal to `ndims`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `unsigned`.
|
|
||||||
//
|
|
||||||
// NOTE: `shape` can contain 0.
|
|
||||||
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
|
||||||
SizeT *shape;
|
|
||||||
|
|
||||||
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
|
||||||
//
|
|
||||||
// The `SizeT` is treated as `signed`.
|
|
||||||
//
|
|
||||||
// NOTE: `strides` can have negative numbers.
|
|
||||||
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
|
||||||
SizeT *strides;
|
|
||||||
|
|
||||||
// Calculate the size/# of elements of an `ndarray`.
|
|
||||||
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
|
||||||
SizeT size() {
|
|
||||||
return ndarray_util::calc_size_from_shape(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
|
||||||
// This function corresponds to `ndarray.nbytes`
|
|
||||||
SizeT nbytes() {
|
|
||||||
return this->size() * itemsize;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) {
|
|
||||||
__builtin_memcpy(pelement, pvalue, itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* get_pelement(const SizeT *indices) {
|
|
||||||
uint8_t* element = data;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
|
||||||
element += indices[dim_i] * strides[dim_i];
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* get_nth_pelement(SizeT nth) {
|
|
||||||
irrt_assert(0 <= nth);
|
|
||||||
irrt_assert(nth < this->size());
|
|
||||||
|
|
||||||
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
|
||||||
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
|
||||||
return get_pelement(indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get pointer to the first element of this ndarray, assuming
|
|
||||||
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
|
|
||||||
//
|
|
||||||
// This is particularly useful for when the ndarray is just containing a single scalar.
|
|
||||||
uint8_t* get_first_pelement() {
|
|
||||||
irrt_assert(this->size() > 0);
|
|
||||||
return this->data; // ...It is simply `this->data`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is the given `indices` valid/in-bounds?
|
|
||||||
bool in_bounds(const SizeT *indices) {
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
|
||||||
bool dim_ok = indices[dim_i] < shape[dim_i];
|
|
||||||
if (!dim_ok) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill the ndarray with a value
|
|
||||||
void fill_generic(const uint8_t* pvalue) {
|
|
||||||
NDArrayIndicesIter<SizeT> iter;
|
|
||||||
iter.ndims = this->ndims;
|
|
||||||
iter.shape = this->shape;
|
|
||||||
iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims);
|
|
||||||
iter.set_indices_zero();
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
|
|
||||||
uint8_t* pelement = get_pelement(iter.indices);
|
|
||||||
set_value_at_pelement(pelement, pvalue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
|
||||||
void set_strides_by_shape() {
|
|
||||||
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
|
||||||
void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) {
|
|
||||||
__builtin_assume(ndims == 2);
|
|
||||||
|
|
||||||
// TODO: Better implementation
|
|
||||||
|
|
||||||
fill_generic(zero_pvalue);
|
|
||||||
for (SizeT i = 0; i < min(shape[0], shape[1]); i++) {
|
|
||||||
SizeT row = i;
|
|
||||||
SizeT col = i + k;
|
|
||||||
SizeT indices[2] = { row, col };
|
|
||||||
|
|
||||||
if (!in_bounds(indices)) continue;
|
|
||||||
|
|
||||||
uint8_t* pelement = get_pelement(indices);
|
|
||||||
set_value_at_pelement(pelement, one_pvalue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
|
|
||||||
//
|
|
||||||
// Things assumed by this function:
|
|
||||||
// - `dst_ndarray` is allocated by the caller
|
|
||||||
// - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`).
|
|
||||||
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
|
|
||||||
//
|
|
||||||
// Other notes:
|
|
||||||
// - `dst_ndarray->data` does not have to be set, it will be derived.
|
|
||||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize`
|
|
||||||
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
|
||||||
void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
|
||||||
|
|
||||||
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
|
|
||||||
|
|
||||||
dst_ndarray->data = this->data;
|
|
||||||
|
|
||||||
SizeT this_axis = 0;
|
|
||||||
SizeT dst_axis = 0;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_ndslices; i++) {
|
|
||||||
NDSlice *ndslice = &ndslices[i];
|
|
||||||
if (ndslice->type == INPUT_SLICE_TYPE_INDEX) {
|
|
||||||
// Handle when the ndslice is just a single (possibly negative) integer
|
|
||||||
// e.g., `my_array[::2, -5, ::-1]`
|
|
||||||
// ^^------ like this
|
|
||||||
SizeT index_user = *((SizeT*) ndslice->slice);
|
|
||||||
SizeT index = resolve_index_in_length(this->shape[this_axis], index_user);
|
|
||||||
dst_ndarray->data += index * this->strides[this_axis]; // Add offset
|
|
||||||
|
|
||||||
// Next
|
|
||||||
this_axis++;
|
|
||||||
} else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) {
|
|
||||||
// Handle when the ndslice is a slice (represented by UserSlice in IRRT)
|
|
||||||
// e.g., `my_array[::2, -5, ::-1]`
|
|
||||||
// ^^^------^^^^----- like these
|
|
||||||
UserSlice<SizeT>* user_slice = (UserSlice<SizeT>*) ndslice->slice;
|
|
||||||
Slice<SizeT> slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user
|
|
||||||
|
|
||||||
// NOTE: There is no need to write special code to handle negative steps/strides.
|
|
||||||
// This simple implementation meticulously handles both positive and negative steps/strides.
|
|
||||||
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
|
|
||||||
dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
|
|
||||||
dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride
|
|
||||||
dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension
|
|
||||||
|
|
||||||
// Next
|
|
||||||
dst_axis++;
|
|
||||||
this_axis++;
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
|
||||||
}
|
|
||||||
|
|
||||||
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
|
||||||
// Assumptions:
|
|
||||||
// - `this` has to be fully initialized.
|
|
||||||
// - `dst_ndarray->ndims` has to be set.
|
|
||||||
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
|
|
||||||
//
|
|
||||||
// Other notes:
|
|
||||||
// - `dst_ndarray->data` does not have to be set, it will be set to `this->data`.
|
|
||||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`.
|
|
||||||
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
|
|
||||||
//
|
|
||||||
// Cautions:
|
|
||||||
// ```
|
|
||||||
// xs = np.zeros((4,))
|
|
||||||
// ys = np.zero((4, 1))
|
|
||||||
// ys[:] = xs # ok
|
|
||||||
//
|
|
||||||
// xs = np.zeros((1, 4))
|
|
||||||
// ys = np.zero((4,))
|
|
||||||
// ys[:] = xs # allowed
|
|
||||||
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
|
||||||
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
|
||||||
// # This implementation will NOT support this assignment.
|
|
||||||
// ```
|
|
||||||
void broadcast_to(NDArray<SizeT>* dst_ndarray) {
|
|
||||||
dst_ndarray->data = this->data;
|
|
||||||
dst_ndarray->itemsize = this->itemsize;
|
|
||||||
|
|
||||||
irrt_assert(
|
|
||||||
ndarray_util::can_broadcast_shape_to(
|
|
||||||
dst_ndarray->ndims,
|
|
||||||
dst_ndarray->shape,
|
|
||||||
this->ndims,
|
|
||||||
this->shape
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) {
|
|
||||||
SizeT this_dim_i = this->ndims - i - 1;
|
|
||||||
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
|
||||||
|
|
||||||
bool this_dim_exists = this_dim_i >= 0;
|
|
||||||
bool dst_dim_exists = dst_dim_i >= 0;
|
|
||||||
|
|
||||||
// TODO: Explain how this works
|
|
||||||
bool c1 = this_dim_exists && this->shape[this_dim_i] == 1;
|
|
||||||
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
|
||||||
if (!this_dim_exists || (c1 && c2)) {
|
|
||||||
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
|
||||||
} else {
|
|
||||||
dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize;
|
|
||||||
stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting.
|
|
||||||
// Caution on https://github.com/numpy/numpy/issues/21744
|
|
||||||
// Also see `NDArray::broadcast_to`
|
|
||||||
void assign_with(NDArray<SizeT>* src_ndarray) {
|
|
||||||
irrt_assert(
|
|
||||||
ndarray_util::can_broadcast_shape_to(
|
|
||||||
this->ndims,
|
|
||||||
this->shape,
|
|
||||||
src_ndarray->ndims,
|
|
||||||
src_ndarray->shape
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Broadcast the `src_ndarray` to make the reading process *much* easier
|
|
||||||
SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand
|
|
||||||
NDArray<SizeT> broadcasted_src_ndarray = {
|
|
||||||
.ndims = this->ndims,
|
|
||||||
.shape = this->shape,
|
|
||||||
.strides = broadcasted_src_ndarray_strides
|
|
||||||
};
|
|
||||||
src_ndarray->broadcast_to(&broadcasted_src_ndarray);
|
|
||||||
|
|
||||||
// Using iter instead of `get_nth_pelement` because it is slightly faster
|
|
||||||
SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims);
|
|
||||||
auto iter = NDArrayIndicesIter<SizeT> {
|
|
||||||
.ndims = this->ndims,
|
|
||||||
.shape = this->shape,
|
|
||||||
.indices = indices
|
|
||||||
};
|
|
||||||
const SizeT this_size = this->size();
|
|
||||||
for (SizeT i = 0; i < this_size; i++, iter.next()) {
|
|
||||||
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices);
|
|
||||||
uint8_t* this_pelement = this->get_pelement(indices);
|
|
||||||
this->set_value_at_pelement(src_pelement, src_pelement);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
||||||
return ndarray->size();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
||||||
return ndarray->size();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
|
||||||
ndarray->fill_generic(pvalue);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
|
||||||
ndarray->fill_generic(pvalue);
|
|
||||||
}
|
|
||||||
|
|
||||||
// void __nac3_ndarray_slice(NDArray<int32_t>* ndarray, int32_t num_slices, NDSlice<int32_t> *slices, NDArray<int32_t> *dst_ndarray) {
|
|
||||||
// // ndarray->slice(num_slices, slices, dst_ndarray);
|
|
||||||
// }
|
|
||||||
}
|
|
|
@ -1,80 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt_utils.hpp"
|
|
||||||
#include "irrt_typedefs.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// A proper slice in IRRT, all negative indices have be resolved to absolute values.
|
|
||||||
// Even though nac3core's slices are always `int32_t`, we will template slice anyway
|
|
||||||
// since this struct is used as a general utility.
|
|
||||||
template <typename T>
|
|
||||||
struct Slice {
|
|
||||||
T start;
|
|
||||||
T stop;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
// The length/The number of elements of the slice if it were a range,
|
|
||||||
// i.e., the value of `len(range(this->start, this->stop, this->end))`
|
|
||||||
T len() {
|
|
||||||
T diff = stop - start;
|
|
||||||
if (diff > 0 && step > 0) {
|
|
||||||
return ((diff - 1) / step) + 1;
|
|
||||||
} else if (diff < 0 && step < 0) {
|
|
||||||
return ((diff + 1) / step) + 1;
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
T resolve_index_in_length(T length, T index) {
|
|
||||||
irrt_assert(length >= 0);
|
|
||||||
if (index < 0) {
|
|
||||||
// Remember that index is negative, so do a plus here
|
|
||||||
return max(length + index, 0);
|
|
||||||
} else {
|
|
||||||
return min(length, index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: using a bitfield for the `*_defined` is better, at the
|
|
||||||
// cost of a more annoying implementation in nac3core inkwell
|
|
||||||
template <typename T>
|
|
||||||
struct UserSlice {
|
|
||||||
uint8_t start_defined;
|
|
||||||
T start;
|
|
||||||
|
|
||||||
uint8_t stop_defined;
|
|
||||||
T stop;
|
|
||||||
|
|
||||||
uint8_t step_defined;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
// Like Python's `slice(start, stop, step).indices(length)`
|
|
||||||
Slice<T> indices(T length) {
|
|
||||||
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
|
||||||
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
|
||||||
irrt_assert(length >= 0);
|
|
||||||
irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user
|
|
||||||
|
|
||||||
Slice<T> result;
|
|
||||||
result.step = step_defined ? step : 1;
|
|
||||||
bool step_is_negative = result.step < 0;
|
|
||||||
|
|
||||||
if (start_defined) {
|
|
||||||
result.start = resolve_index_in_length(length, start);
|
|
||||||
} else {
|
|
||||||
result.start = step_is_negative ? length - 1 : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stop_defined) {
|
|
||||||
result.stop = resolve_index_in_length(length, stop);
|
|
||||||
} else {
|
|
||||||
result.stop = step_is_negative ? -1 : length;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
|
@ -5,654 +5,14 @@
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them
|
#include <irrt_everything.hpp>
|
||||||
#define IRRT_DONT_TYPEDEF_INTS
|
|
||||||
#include "irrt_everything.hpp"
|
|
||||||
|
|
||||||
void test_fail() {
|
#include <test/core.hpp>
|
||||||
printf("[!] Test failed\n");
|
#include <test/test_core.hpp>
|
||||||
exit(1);
|
#include <test/test_utils.hpp>
|
||||||
}
|
|
||||||
|
|
||||||
void __begin_test(const char* function_name, const char* file, int line) {
|
|
||||||
printf("######### Running %s @ %s:%d\n", function_name, file, line);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void debug_print_array(const char* format, int len, T* as) {
|
|
||||||
printf("[");
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
if (i != 0) printf(", ");
|
|
||||||
printf(format, as[i]);
|
|
||||||
}
|
|
||||||
printf("]");
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
|
|
||||||
if (!arrays_match(len, expected, got)) {
|
|
||||||
printf(">>>>>>> %s\n", label);
|
|
||||||
printf(" Expecting = ");
|
|
||||||
debug_print_array(format, len, expected);
|
|
||||||
printf("\n");
|
|
||||||
printf(" Got = ");
|
|
||||||
debug_print_array(format, len, got);
|
|
||||||
printf("\n");
|
|
||||||
test_fail();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void assert_values_match(const char* label, const char* format, T expected, T got) {
|
|
||||||
if (expected != got) {
|
|
||||||
printf(">>>>>>> %s\n", label);
|
|
||||||
printf(" Expecting = ");
|
|
||||||
printf(format, expected);
|
|
||||||
printf("\n");
|
|
||||||
printf(" Got = ");
|
|
||||||
printf(format, got);
|
|
||||||
printf("\n");
|
|
||||||
test_fail();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_repeated(const char *str, int count) {
|
|
||||||
for (int i = 0; i < count; i++) {
|
|
||||||
printf("%s", str);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename SizeT, typename ElementT>
|
|
||||||
void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
|
|
||||||
// A really lazy recursive implementation
|
|
||||||
|
|
||||||
// Add left padding unless its the first entry (since there would be "[[[" before it)
|
|
||||||
if (!first) {
|
|
||||||
print_repeated(" ", depth);
|
|
||||||
}
|
|
||||||
|
|
||||||
const SizeT dim = ndarray->shape[depth];
|
|
||||||
if (depth + 1 == ndarray->ndims) {
|
|
||||||
// Recursed down to last dimension, print the values in a nice list
|
|
||||||
printf("[");
|
|
||||||
|
|
||||||
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
|
|
||||||
for (SizeT i = 0; i < dim; i++) {
|
|
||||||
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
|
|
||||||
ElementT* pelement = (ElementT*) ndarray->get_pelement(indices);
|
|
||||||
ElementT element = *pelement;
|
|
||||||
|
|
||||||
if (i != 0) printf(", "); // List delimiter
|
|
||||||
printf(format, element);
|
|
||||||
printf("(@");
|
|
||||||
debug_print_array("%d", ndarray->ndims, indices);
|
|
||||||
printf(")");
|
|
||||||
|
|
||||||
(*cursor)++;
|
|
||||||
}
|
|
||||||
printf("]");
|
|
||||||
} else {
|
|
||||||
printf("[");
|
|
||||||
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
|
|
||||||
__print_ndarray_aux<SizeT, ElementT>(
|
|
||||||
format,
|
|
||||||
i == 0, // first?
|
|
||||||
i + 1 == dim, // last?
|
|
||||||
cursor,
|
|
||||||
depth + 1,
|
|
||||||
ndarray
|
|
||||||
);
|
|
||||||
}
|
|
||||||
printf("]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add newline unless its the last entry (since there will be "]]]" after it)
|
|
||||||
if (!last) {
|
|
||||||
print_repeated("\n", depth);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename SizeT, typename ElementT>
|
|
||||||
void print_ndarray(const char *format, NDArray<SizeT>* ndarray) {
|
|
||||||
if (ndarray->ndims == 0) {
|
|
||||||
printf("<empty ndarray>");
|
|
||||||
} else {
|
|
||||||
SizeT cursor = 0;
|
|
||||||
__print_ndarray_aux<SizeT, ElementT>(format, true, true, &cursor, 0, ndarray);
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_calc_size_from_shape_normal() {
|
|
||||||
// Test shapes with normal values
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
int32_t shape[4] = { 2, 3, 5, 7 };
|
|
||||||
assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_calc_size_from_shape_has_zero() {
|
|
||||||
// Test shapes with 0 in them
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
int32_t shape[4] = { 2, 0, 5, 7 };
|
|
||||||
assert_values_match("size", "%d", 0, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_set_strides_by_shape() {
|
|
||||||
// Test `set_strides_by_shape()`
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
int32_t shape[4] = { 99, 3, 5, 7 };
|
|
||||||
int32_t strides[4] = { 0 };
|
|
||||||
ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
|
||||||
|
|
||||||
int32_t expected_strides[4] = {
|
|
||||||
105 * sizeof(int32_t),
|
|
||||||
35 * sizeof(int32_t),
|
|
||||||
7 * sizeof(int32_t),
|
|
||||||
1 * sizeof(int32_t)
|
|
||||||
};
|
|
||||||
assert_arrays_match("strides", "%u", 4u, expected_strides, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndarray_indices_iter_normal() {
|
|
||||||
// Test NDArrayIndicesIter normal behavior
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
int32_t shape[3] = { 1, 2, 3 };
|
|
||||||
int32_t indices[3] = { 0, 0, 0 };
|
|
||||||
auto iter = NDArrayIndicesIter<int32_t> {
|
|
||||||
.ndims = 3,
|
|
||||||
.shape = shape,
|
|
||||||
.indices = indices
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 });
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back
|
|
||||||
iter.next();
|
|
||||||
assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndarray_fill_generic() {
|
|
||||||
// Test ndarray fill_generic
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
// Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up
|
|
||||||
// Also make all the octets non-zero, to see if `memcpy` in `fill_generic` is working perfectly.
|
|
||||||
uint16_t fill_value = 0xFACE;
|
|
||||||
|
|
||||||
uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 }; // Fill `data` with values that != `999`
|
|
||||||
int32_t in_itemsize = sizeof(uint16_t);
|
|
||||||
const int32_t in_ndims = 2;
|
|
||||||
int32_t in_shape[in_ndims] = { 2, 3 };
|
|
||||||
int32_t in_strides[in_ndims] = {};
|
|
||||||
NDArray<int32_t> ndarray = {
|
|
||||||
.data = (uint8_t*) in_data,
|
|
||||||
.itemsize = in_itemsize,
|
|
||||||
.ndims = in_ndims,
|
|
||||||
.shape = in_shape,
|
|
||||||
.strides = in_strides,
|
|
||||||
};
|
|
||||||
ndarray.set_strides_by_shape();
|
|
||||||
ndarray.fill_generic((uint8_t*) &fill_value); // `fill_generic` here
|
|
||||||
|
|
||||||
uint16_t expected_data[6] = { fill_value, fill_value, fill_value, fill_value, fill_value, fill_value };
|
|
||||||
assert_arrays_match("data", "0x%hX", 6, expected_data, in_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndarray_set_to_eye() {
|
|
||||||
// Test `set_to_eye` behavior (helper function to implement `np.eye()`)
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
double in_data[9] = { 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0 };
|
|
||||||
int32_t in_itemsize = sizeof(double);
|
|
||||||
const int32_t in_ndims = 2;
|
|
||||||
int32_t in_shape[in_ndims] = { 3, 3 };
|
|
||||||
int32_t in_strides[in_ndims] = {};
|
|
||||||
NDArray<int32_t> ndarray = {
|
|
||||||
.data = (uint8_t*) in_data,
|
|
||||||
.itemsize = in_itemsize,
|
|
||||||
.ndims = in_ndims,
|
|
||||||
.shape = in_shape,
|
|
||||||
.strides = in_strides,
|
|
||||||
};
|
|
||||||
ndarray.set_strides_by_shape();
|
|
||||||
|
|
||||||
double zero = 0.0;
|
|
||||||
double one = 1.0;
|
|
||||||
ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one);
|
|
||||||
|
|
||||||
assert_values_match("in_data[0]", "%f", 0.0, in_data[0]);
|
|
||||||
assert_values_match("in_data[1]", "%f", 1.0, in_data[1]);
|
|
||||||
assert_values_match("in_data[2]", "%f", 0.0, in_data[2]);
|
|
||||||
assert_values_match("in_data[3]", "%f", 0.0, in_data[3]);
|
|
||||||
assert_values_match("in_data[4]", "%f", 0.0, in_data[4]);
|
|
||||||
assert_values_match("in_data[5]", "%f", 1.0, in_data[5]);
|
|
||||||
assert_values_match("in_data[6]", "%f", 0.0, in_data[6]);
|
|
||||||
assert_values_match("in_data[7]", "%f", 0.0, in_data[7]);
|
|
||||||
assert_values_match("in_data[8]", "%f", 0.0, in_data[8]);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_slice_1() {
|
|
||||||
// Test `slice(5, None, None).indices(100) == slice(5, 100, 1)`
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
UserSlice<int> user_slice = {
|
|
||||||
.start_defined = 1,
|
|
||||||
.start = 5,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
auto slice = user_slice.indices(100);
|
|
||||||
assert_values_match("start", "%d", 5, slice.start);
|
|
||||||
assert_values_match("stop", "%d", 100, slice.stop);
|
|
||||||
assert_values_match("step", "%d", 1, slice.step);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_slice_2() {
|
|
||||||
// Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)`
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
UserSlice<int> user_slice = {
|
|
||||||
.start_defined = 1,
|
|
||||||
.start = 400,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
auto slice = user_slice.indices(100);
|
|
||||||
assert_values_match("start", "%d", 100, slice.start);
|
|
||||||
assert_values_match("stop", "%d", 100, slice.stop);
|
|
||||||
assert_values_match("step", "%d", 1, slice.step);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_slice_3() {
|
|
||||||
// Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)`
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
UserSlice<int> user_slice = {
|
|
||||||
.start_defined = 1,
|
|
||||||
.start = -10,
|
|
||||||
.stop_defined = 1,
|
|
||||||
.stop = -5,
|
|
||||||
.step_defined = 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
auto slice = user_slice.indices(100);
|
|
||||||
assert_values_match("start", "%d", 90, slice.start);
|
|
||||||
assert_values_match("stop", "%d", 95, slice.stop);
|
|
||||||
assert_values_match("step", "%d", 1, slice.step);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_slice_4() {
|
|
||||||
// Test `slice(None, None, -5).indices(100) == (99, -1, -5)`
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
UserSlice<int> user_slice = {
|
|
||||||
.start_defined = 0,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 1,
|
|
||||||
.step = -5
|
|
||||||
};
|
|
||||||
|
|
||||||
auto slice = user_slice.indices(100);
|
|
||||||
assert_values_match("start", "%d", 99, slice.start);
|
|
||||||
assert_values_match("stop", "%d", -1, slice.stop);
|
|
||||||
assert_values_match("step", "%d", -5, slice.step);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndslice_1() {
|
|
||||||
/*
|
|
||||||
Reference Python code:
|
|
||||||
```python
|
|
||||||
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
|
||||||
# array([[ 0., 1., 2., 3.],
|
|
||||||
# [ 4., 5., 6., 7.],
|
|
||||||
# [ 8., 9., 10., 11.]])
|
|
||||||
|
|
||||||
dst_ndarray = ndarray[-2:, 1::2]
|
|
||||||
# array([[ 5., 7.],
|
|
||||||
# [ 9., 11.]])
|
|
||||||
|
|
||||||
assert dst_ndarray.shape == (2, 2)
|
|
||||||
assert dst_ndarray.strides == (32, 16)
|
|
||||||
assert dst_ndarray[0, 0] == 5.0
|
|
||||||
assert dst_ndarray[0, 1] == 7.0
|
|
||||||
assert dst_ndarray[1, 0] == 9.0
|
|
||||||
assert dst_ndarray[1, 1] == 11.0
|
|
||||||
```
|
|
||||||
*/
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
||||||
int32_t in_itemsize = sizeof(double);
|
|
||||||
const int32_t in_ndims = 2;
|
|
||||||
int32_t in_shape[in_ndims] = { 3, 4 };
|
|
||||||
int32_t in_strides[in_ndims] = {};
|
|
||||||
NDArray<int32_t> ndarray = {
|
|
||||||
.data = (uint8_t*) in_data,
|
|
||||||
.itemsize = in_itemsize,
|
|
||||||
.ndims = in_ndims,
|
|
||||||
.shape = in_shape,
|
|
||||||
.strides = in_strides
|
|
||||||
};
|
|
||||||
ndarray.set_strides_by_shape();
|
|
||||||
|
|
||||||
// Destination ndarray
|
|
||||||
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
|
||||||
const int32_t dst_ndims = 2;
|
|
||||||
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
|
||||||
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
|
||||||
NDArray<int32_t> dst_ndarray = {
|
|
||||||
.data = nullptr,
|
|
||||||
.ndims = dst_ndims,
|
|
||||||
.shape = dst_shape,
|
|
||||||
.strides = dst_strides
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the slice in `ndarray[-2::, 1::2]`
|
|
||||||
UserSlice<int32_t> user_slice_1 = {
|
|
||||||
.start_defined = 1,
|
|
||||||
.start = -2,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 0
|
|
||||||
};
|
|
||||||
|
|
||||||
UserSlice<int32_t> user_slice_2 = {
|
|
||||||
.start_defined = 1,
|
|
||||||
.start = 1,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 1,
|
|
||||||
.step = 2
|
|
||||||
};
|
|
||||||
|
|
||||||
const int32_t num_ndslices = 2;
|
|
||||||
NDSlice ndslices[num_ndslices] = {
|
|
||||||
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
|
|
||||||
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
|
||||||
};
|
|
||||||
|
|
||||||
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
|
||||||
|
|
||||||
int32_t expected_shape[dst_ndims] = { 2, 2 };
|
|
||||||
int32_t expected_strides[dst_ndims] = { 32, 16 };
|
|
||||||
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
|
||||||
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
|
||||||
|
|
||||||
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 })));
|
|
||||||
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 })));
|
|
||||||
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 })));
|
|
||||||
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 })));
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndslice_2() {
|
|
||||||
/*
|
|
||||||
```python
|
|
||||||
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
|
||||||
# array([[ 0., 1., 2., 3.],
|
|
||||||
# [ 4., 5., 6., 7.],
|
|
||||||
# [ 8., 9., 10., 11.]])
|
|
||||||
|
|
||||||
dst_ndarray = ndarray[2, ::-2]
|
|
||||||
# array([11., 9.])
|
|
||||||
|
|
||||||
assert dst_ndarray.shape == (2,)
|
|
||||||
assert dst_ndarray.strides == (-16,)
|
|
||||||
assert dst_ndarray[0] == 11.0
|
|
||||||
assert dst_ndarray[1] == 9.0
|
|
||||||
|
|
||||||
dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray`
|
|
||||||
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
|
|
||||||
```
|
|
||||||
*/
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
|
||||||
int32_t in_itemsize = sizeof(double);
|
|
||||||
const int32_t in_ndims = 2;
|
|
||||||
int32_t in_shape[in_ndims] = { 3, 4 };
|
|
||||||
int32_t in_strides[in_ndims] = {};
|
|
||||||
NDArray<int32_t> ndarray = {
|
|
||||||
.data = (uint8_t*) in_data,
|
|
||||||
.itemsize = in_itemsize,
|
|
||||||
.ndims = in_ndims,
|
|
||||||
.shape = in_shape,
|
|
||||||
.strides = in_strides
|
|
||||||
};
|
|
||||||
ndarray.set_strides_by_shape();
|
|
||||||
|
|
||||||
// Destination ndarray
|
|
||||||
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
|
||||||
const int32_t dst_ndims = 1;
|
|
||||||
int32_t dst_shape[dst_ndims] = {999}; // Empty values
|
|
||||||
int32_t dst_strides[dst_ndims] = {999}; // Empty values
|
|
||||||
NDArray<int32_t> dst_ndarray = {
|
|
||||||
.data = nullptr,
|
|
||||||
.ndims = dst_ndims,
|
|
||||||
.shape = dst_shape,
|
|
||||||
.strides = dst_strides
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the slice in `ndarray[2, ::-2]`
|
|
||||||
int32_t user_slice_1 = 2;
|
|
||||||
UserSlice<int32_t> user_slice_2 = {
|
|
||||||
.start_defined = 0,
|
|
||||||
.stop_defined = 0,
|
|
||||||
.step_defined = 1,
|
|
||||||
.step = -2
|
|
||||||
};
|
|
||||||
|
|
||||||
const int32_t num_ndslices = 2;
|
|
||||||
NDSlice ndslices[num_ndslices] = {
|
|
||||||
{ .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 },
|
|
||||||
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
|
|
||||||
};
|
|
||||||
|
|
||||||
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
|
|
||||||
|
|
||||||
int32_t expected_shape[dst_ndims] = { 2 };
|
|
||||||
int32_t expected_strides[dst_ndims] = { -16 };
|
|
||||||
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
|
|
||||||
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
|
|
||||||
|
|
||||||
// [5.0, 3.0]
|
|
||||||
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 })));
|
|
||||||
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_can_broadcast_shape() {
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 5, (int32_t[]) { 1, 1, 1, 1, 3 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([3], [3, 1]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 2, (int32_t[]) { 3, 1 }));
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([3], [3]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 1, (int32_t[]) { 3 }));
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([1], [3]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 3 }));
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([1], [1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 1 }));
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 3, (int32_t[]) { 256, 1, 3 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([256, 256, 3], [3]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 3 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([256, 256, 3], [2]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 2 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([256, 256, 3], [1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 1 })
|
|
||||||
);
|
|
||||||
|
|
||||||
// In cases when the shapes contain zero(es)
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([0], [1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 1 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([0], [2]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 2 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([0, 4, 0, 0], [1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 1, (int32_t[]) { 1 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 1, 1, 1 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true",
|
|
||||||
"%d",
|
|
||||||
true,
|
|
||||||
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 4, 1, 1 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([4, 3], [0, 3]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 3 })
|
|
||||||
);
|
|
||||||
assert_values_match(
|
|
||||||
"can_broadcast_shape_to([4, 3], [0, 0]) == false",
|
|
||||||
"%d",
|
|
||||||
false,
|
|
||||||
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 0 })
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ndarray_broadcast_1() {
|
|
||||||
/*
|
|
||||||
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
|
||||||
# >>> [[19.9 29.9 39.9 49.9]]
|
|
||||||
#
|
|
||||||
# array = np.broadcast_to(array, (2, 3, 4))
|
|
||||||
# >>> [[[19.9 29.9 39.9 49.9]
|
|
||||||
# >>> [19.9 29.9 39.9 49.9]
|
|
||||||
# >>> [19.9 29.9 39.9 49.9]]
|
|
||||||
# >>> [[19.9 29.9 39.9 49.9]
|
|
||||||
# >>> [19.9 29.9 39.9 49.9]
|
|
||||||
# >>> [19.9 29.9 39.9 49.9]]]
|
|
||||||
#
|
|
||||||
# assery array.strides == (0, 0, 8)
|
|
||||||
|
|
||||||
*/
|
|
||||||
BEGIN_TEST();
|
|
||||||
|
|
||||||
double in_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
|
||||||
const int32_t in_ndims = 2;
|
|
||||||
int32_t in_shape[in_ndims] = {1, 4};
|
|
||||||
int32_t in_strides[in_ndims] = {};
|
|
||||||
NDArray<int32_t> ndarray = {
|
|
||||||
.data = (uint8_t*) in_data,
|
|
||||||
.itemsize = sizeof(double),
|
|
||||||
.ndims = in_ndims,
|
|
||||||
.shape = in_shape,
|
|
||||||
.strides = in_strides
|
|
||||||
};
|
|
||||||
ndarray.set_strides_by_shape();
|
|
||||||
|
|
||||||
const int32_t dst_ndims = 3;
|
|
||||||
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
|
||||||
int32_t dst_strides[dst_ndims] = {};
|
|
||||||
NDArray<int32_t> dst_ndarray = {
|
|
||||||
.ndims = dst_ndims,
|
|
||||||
.shape = dst_shape,
|
|
||||||
.strides = dst_strides
|
|
||||||
};
|
|
||||||
|
|
||||||
ndarray.broadcast_to(&dst_ndarray);
|
|
||||||
|
|
||||||
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
|
|
||||||
|
|
||||||
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0})));
|
|
||||||
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1})));
|
|
||||||
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2})));
|
|
||||||
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3})));
|
|
||||||
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0})));
|
|
||||||
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1})));
|
|
||||||
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2})));
|
|
||||||
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3})));
|
|
||||||
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3})));
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_assign_with() {
|
|
||||||
/*
|
|
||||||
```
|
|
||||||
xs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float64)
|
|
||||||
ys = xs.shape
|
|
||||||
```
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_calc_size_from_shape_normal();
|
run_test_core();
|
||||||
test_calc_size_from_shape_has_zero();
|
run_test_utils();
|
||||||
test_set_strides_by_shape();
|
|
||||||
test_ndarray_indices_iter_normal();
|
|
||||||
test_ndarray_fill_generic();
|
|
||||||
test_ndarray_set_to_eye();
|
|
||||||
test_slice_1();
|
|
||||||
test_slice_2();
|
|
||||||
test_slice_3();
|
|
||||||
test_slice_4();
|
|
||||||
test_ndslice_1();
|
|
||||||
test_ndslice_2();
|
|
||||||
test_can_broadcast_shape();
|
|
||||||
test_ndarray_broadcast_1();
|
|
||||||
test_assign_with();
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -1,37 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt_typedefs.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template <typename T>
|
|
||||||
T max(T a, T b) {
|
|
||||||
return a > b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T min(T a, T b) {
|
|
||||||
return a > b ? b : a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
bool arrays_match(int len, T *as, T *bs) {
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
if (as[i] != bs[i]) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void irrt_panic() {
|
|
||||||
// Crash the program for now.
|
|
||||||
// TODO: Don't crash the program
|
|
||||||
// ... or at least produce a good message when doing testing IRRT
|
|
||||||
|
|
||||||
uint8_t* death = nullptr;
|
|
||||||
*death = 0; // TODO: address 0 on hardware might be writable?
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Make this a macro and allow it to be toggled on/off (e.g., debug vs release)
|
|
||||||
void irrt_assert(bool condition) {
|
|
||||||
if (!condition) irrt_panic();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Include this header for every test_*.cpp
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <test/print.hpp>
|
||||||
|
|
||||||
|
// Some utils can be used here
|
||||||
|
#include "../irrt/utils.hpp"
|
||||||
|
|
||||||
|
void __begin_test(const char* function_name, const char* file, int line) {
|
||||||
|
printf("######### Running %s @ %s:%d\n", function_name, file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
||||||
|
|
||||||
|
void test_fail() {
|
||||||
|
printf("[!] Test failed. Exiting with status code 1.\n");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void debug_print_array(int len, const T* as) {
|
||||||
|
printf("[");
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
if (i != 0) printf(", ");
|
||||||
|
print_value(as[i]);
|
||||||
|
}
|
||||||
|
printf("]");
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_assertion_passed(const char* file, int line) {
|
||||||
|
printf("[*] Assertion passed on %s:%d\n", file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_assertion_failed(const char* file, int line) {
|
||||||
|
printf("[!] Assertion failed on %s:%d\n", file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __assert_true(const char* file, int line, bool cond) {
|
||||||
|
if (cond) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_true(cond) __assert_true(__FILE__, __LINE__, cond)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void __assert_arrays_match(const char* file, int line, int len, const T* expected, const T* got) {
|
||||||
|
if (arrays_match(len, expected, got)) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
printf("Expect = ");
|
||||||
|
debug_print_array(len, expected);
|
||||||
|
printf("\n");
|
||||||
|
printf(" Got = ");
|
||||||
|
debug_print_array(len, got);
|
||||||
|
printf("\n");
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_arrays_match(len, expected, got) __assert_arrays_match(__FILE__, __LINE__, len, expected, got)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void __assert_values_match(const char* file, int line, T expected, T got) {
|
||||||
|
if (expected == got) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
printf("Expect = ");
|
||||||
|
print_value(expected);
|
||||||
|
printf("\n");
|
||||||
|
printf(" Got = ");
|
||||||
|
print_value(got);
|
||||||
|
printf("\n");
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_values_match(expected, got) __assert_values_match(__FILE__, __LINE__, expected, got)
|
|
@ -0,0 +1,42 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void print_value(T value);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(char value) {
|
||||||
|
printf("'%c' (ord=%d)", value, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(int8_t value) {
|
||||||
|
printf("%d", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(int32_t value) {
|
||||||
|
printf("%d", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(uint8_t value) {
|
||||||
|
printf("%u", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(uint32_t value) {
|
||||||
|
printf("%u", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(double value) {
|
||||||
|
printf("%f", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(char* value) {
|
||||||
|
printf("%s", value);
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <test/core.hpp>
|
||||||
|
#include <irrt/core.hpp>
|
||||||
|
|
||||||
|
void test_int_exp() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
assert_values_match(125, __nac3_int_exp_impl<int32_t>(5, 3));
|
||||||
|
assert_values_match(3125, __nac3_int_exp_impl<int32_t>(5, 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run_test_core() {
|
||||||
|
test_int_exp();
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <test/core.hpp>
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
|
||||||
|
void test_int_log_10() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
assert_values_match((uint32_t) 0, int_log_floor(0, 10));
|
||||||
|
assert_values_match((uint32_t) 0, int_log_floor(9, 10));
|
||||||
|
assert_values_match((uint32_t) 1, int_log_floor(10, 10));
|
||||||
|
assert_values_match((uint32_t) 1, int_log_floor(11, 10));
|
||||||
|
assert_values_match((uint32_t) 1, int_log_floor(99, 10));
|
||||||
|
assert_values_match((uint32_t) 2, int_log_floor(100, 10));
|
||||||
|
assert_values_match((uint32_t) 2, int_log_floor(101, 10));
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_cstr_utils() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
assert_values_match((uint32_t) 42, (uint32_t) cstr_utils::length("THROWN FROM __nac3_error_dummy_raise!!!!!!"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run_test_utils() {
|
||||||
|
test_int_log_10();
|
||||||
|
test_cstr_utils();
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,6 +1,8 @@
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, CodeGenContext,
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
CodeGenerator,
|
llvm_intrinsics::call_int_umin,
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use inkwell::context::Context;
|
use inkwell::context::Context;
|
||||||
use inkwell::types::{ArrayType, BasicType, StructType};
|
use inkwell::types::{ArrayType, BasicType, StructType};
|
||||||
|
@ -10,7 +12,6 @@ use inkwell::{
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
/// A LLVM type that is used to represent a non-primitive type in NAC3.
|
/// A LLVM type that is used to represent a non-primitive type in NAC3.
|
||||||
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
||||||
|
@ -1600,8 +1601,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
generator: &G,
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
todo!()
|
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
|
||||||
// call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1675,19 +1675,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||||
indices_elem_ty.get_bit_width()
|
indices_elem_ty.get_bit_width()
|
||||||
);
|
);
|
||||||
|
|
||||||
todo!()
|
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
||||||
|
|
||||||
// let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
// unsafe {
|
.build_in_bounds_gep(
|
||||||
// ctx.builder
|
self.base_ptr(ctx, generator),
|
||||||
// .build_in_bounds_gep(
|
&[index],
|
||||||
// self.base_ptr(ctx, generator),
|
name.unwrap_or_default(),
|
||||||
// &[index],
|
)
|
||||||
// name.unwrap_or_default(),
|
.unwrap()
|
||||||
// )
|
}
|
||||||
// .unwrap()
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -1763,307 +1761,3 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
||||||
for NDArrayDataProxy<'ctx, '_>
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct StructField<'ctx> {
|
|
||||||
/// The GEP index of this struct field.
|
|
||||||
pub gep_index: u32,
|
|
||||||
/// Name of this struct field.
|
|
||||||
///
|
|
||||||
/// Used for generating names.
|
|
||||||
pub name: &'static str,
|
|
||||||
/// The type of this struct field.
|
|
||||||
pub ty: BasicTypeEnum<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct StructFields<'ctx> {
|
|
||||||
/// Name of the struct.
|
|
||||||
///
|
|
||||||
/// Used for generating names.
|
|
||||||
pub name: &'static str,
|
|
||||||
|
|
||||||
/// All the [`StructField`]s of this struct.
|
|
||||||
///
|
|
||||||
/// **NOTE:** The index position of a [`StructField`]
|
|
||||||
/// matches the element's [`StructField::index`].
|
|
||||||
pub fields: Vec<StructField<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct StructFieldsBuilder<'ctx> {
|
|
||||||
gep_index_counter: u32,
|
|
||||||
/// Name of the struct to be built.
|
|
||||||
name: &'static str,
|
|
||||||
fields: Vec<StructField<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> StructField<'ctx> {
|
|
||||||
pub fn gep(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ptr: PointerValue<'ctx>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
ctx.builder.build_struct_gep(ptr, self.gep_index, self.name).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ptr: PointerValue<'ctx>,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
ctx.builder.build_load(self.gep(ctx, ptr), self.name).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>, value: V)
|
|
||||||
where
|
|
||||||
V: BasicValue<'ctx>,
|
|
||||||
{
|
|
||||||
ctx.builder.build_store(ptr, value).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type IsInstanceError = String;
|
|
||||||
type IsInstanceResult = Result<(), IsInstanceError>;
|
|
||||||
|
|
||||||
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
|
|
||||||
where
|
|
||||||
A: BasicType<'ctx>,
|
|
||||||
B: BasicType<'ctx>,
|
|
||||||
{
|
|
||||||
let expected = expected.as_basic_type_enum();
|
|
||||||
let got = got.as_basic_type_enum();
|
|
||||||
|
|
||||||
// Put those logic into here,
|
|
||||||
// otherwise there is always a fallback reporting on any kind of mismatch
|
|
||||||
match (expected, got) {
|
|
||||||
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
|
||||||
if expected.get_bit_width() != got.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(expected, got) => {
|
|
||||||
if expected != got {
|
|
||||||
return Err(format!("Expected {expected}, got {got}"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> StructFields<'ctx> {
|
|
||||||
pub fn num_fields(&self) -> u32 {
|
|
||||||
self.fields.len() as u32
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
|
||||||
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
|
|
||||||
ctx.struct_type(llvm_fields.as_slice(), false)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
|
|
||||||
// Check scrutinee's number of struct fields
|
|
||||||
if scrutinee.count_fields() != self.num_fields() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
|
|
||||||
struct_name = self.name,
|
|
||||||
expected_count = self.num_fields(),
|
|
||||||
got_count = scrutinee.count_fields(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the scrutinee's field types
|
|
||||||
for field in self.fields.iter() {
|
|
||||||
let expected_field_ty = field.ty;
|
|
||||||
let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
|
|
||||||
|
|
||||||
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
|
|
||||||
return Err(format!(
|
|
||||||
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
|
|
||||||
gep_index = field.gep_index,
|
|
||||||
struct_name = self.name,
|
|
||||||
field_name = field.name,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Done
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> StructFieldsBuilder<'ctx> {
|
|
||||||
fn start(name: &'static str) -> Self {
|
|
||||||
StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
|
|
||||||
let index = self.gep_index_counter;
|
|
||||||
self.gep_index_counter += 1;
|
|
||||||
StructField { gep_index: index, name, ty }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn end(self) -> StructFields<'ctx> {
|
|
||||||
StructFields { name: self.name, fields: self.fields }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct NpArrayType<'ctx> {
|
|
||||||
pub size_type: IntType<'ctx>,
|
|
||||||
pub elem_type: BasicTypeEnum<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct NpArrayStructFields<'ctx> {
|
|
||||||
pub whole_struct: StructFields<'ctx>,
|
|
||||||
pub data: StructField<'ctx>,
|
|
||||||
pub itemsize: StructField<'ctx>,
|
|
||||||
pub ndims: StructField<'ctx>,
|
|
||||||
pub shape: StructField<'ctx>,
|
|
||||||
pub strides: StructField<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NpArrayType<'ctx> {
|
|
||||||
pub fn new_opaque_elem(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
size_type: IntType<'ctx>,
|
|
||||||
) -> NpArrayType<'ctx> {
|
|
||||||
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn struct_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructType<'ctx> {
|
|
||||||
self.fields().whole_struct.as_struct_type(ctx.ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fields(&self) -> NpArrayStructFields<'ctx> {
|
|
||||||
let mut builder = StructFieldsBuilder::start("NpArray");
|
|
||||||
|
|
||||||
let addrspace = AddressSpace::default();
|
|
||||||
|
|
||||||
let byte_type = self.size_type.get_context().i8_type();
|
|
||||||
|
|
||||||
// Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`.
|
|
||||||
let data = builder.add_field("data", byte_type.ptr_type(addrspace).into());
|
|
||||||
let itemsize = builder.add_field("itemsize", self.size_type.into());
|
|
||||||
let ndims = builder.add_field("ndims", self.size_type.into());
|
|
||||||
let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into());
|
|
||||||
let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into());
|
|
||||||
|
|
||||||
NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Allocate an `ndarray` on stack, with the following notes:
|
|
||||||
///
|
|
||||||
/// - `ndarray.ndims` will be initialized to `in_ndims`.
|
|
||||||
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
|
|
||||||
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
|
|
||||||
/// all with empty/uninitialized values.
|
|
||||||
pub fn alloca(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
in_ndims: IntValue<'ctx>,
|
|
||||||
name: &str,
|
|
||||||
) -> NpArrayValue<'ctx> {
|
|
||||||
let fields = self.fields();
|
|
||||||
let ptr =
|
|
||||||
ctx.builder.build_alloca(fields.whole_struct.as_struct_type(ctx.ctx), name).unwrap();
|
|
||||||
|
|
||||||
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
|
|
||||||
let allocated_shape =
|
|
||||||
ctx.builder.build_array_alloca(fields.shape.ty, in_ndims, "allocated_shape").unwrap();
|
|
||||||
let allocated_strides = ctx
|
|
||||||
.builder
|
|
||||||
.build_array_alloca(fields.strides.ty, in_ndims, "allocated_strides")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let value = NpArrayValue { ty: *self, ptr };
|
|
||||||
value.store_ndims(ctx, in_ndims);
|
|
||||||
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
|
|
||||||
value.store_shape(ctx, allocated_shape);
|
|
||||||
value.store_strides(ctx, allocated_strides);
|
|
||||||
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct NpArrayValue<'ctx> {
|
|
||||||
pub ty: NpArrayType<'ctx>,
|
|
||||||
pub ptr: PointerValue<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NpArrayValue<'ctx> {
|
|
||||||
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let field = self.ty.fields().ndims;
|
|
||||||
field.load(ctx, self.ptr).into_int_value()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
|
||||||
let field = self.ty.fields().ndims;
|
|
||||||
field.store(ctx, self.ptr, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let field = self.ty.fields().itemsize;
|
|
||||||
field.load(ctx, self.ptr).into_int_value()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
|
||||||
let field = self.ty.fields().itemsize;
|
|
||||||
field.store(ctx, self.ptr, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let field = self.ty.fields().shape;
|
|
||||||
field.load(ctx, self.ptr).into_pointer_value()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
|
||||||
let field = self.ty.fields().shape;
|
|
||||||
field.store(ctx, self.ptr, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let field = self.ty.fields().strides;
|
|
||||||
field.load(ctx, self.ptr).into_pointer_value()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
|
||||||
let field = self.ty.fields().strides;
|
|
||||||
field.store(ctx, self.ptr, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
|
|
||||||
pub fn shape_slice(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
|
||||||
let field = self.ty.fields().shape;
|
|
||||||
field.gep(ctx, self.ptr);
|
|
||||||
|
|
||||||
let ndims = self.load_ndims(ctx);
|
|
||||||
|
|
||||||
TypedArrayLikeAdapter {
|
|
||||||
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
|
|
||||||
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
|
|
||||||
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
|
|
||||||
pub fn strides_slice(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
|
||||||
let field = self.ty.fields().strides;
|
|
||||||
field.gep(ctx, self.ptr);
|
|
||||||
|
|
||||||
let ndims = self.load_ndims(ctx);
|
|
||||||
|
|
||||||
TypedArrayLikeAdapter {
|
|
||||||
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
|
|
||||||
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
|
|
||||||
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,196 @@
|
||||||
|
use inkwell::types::IntType;
|
||||||
|
use inkwell::values::IntValue;
|
||||||
|
|
||||||
|
use crate::codegen::optics::*;
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
use crate::codegen::CodeGenerator;
|
||||||
|
|
||||||
|
use super::util::get_sized_dependent_function_name;
|
||||||
|
use super::util::FunctionBuilder;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StrLens<'ctx> {
|
||||||
|
pub size_type: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: nac3core has hardcoded a lot of "str"
|
||||||
|
pub struct StrFields<'ctx> {
|
||||||
|
pub content: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
|
pub length: GepGetter<IntLens<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> StructureOptic<'ctx> for StrLens<'ctx> {
|
||||||
|
type Fields = StrFields<'ctx>;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"str"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||||
|
StrFields {
|
||||||
|
content: builder.add_field("content", AddressLens(IntLens(builder.ctx.i8_type()))),
|
||||||
|
length: builder.add_field("length", IntLens(self.size_type)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ErrorIdsFields<'ctx> {
|
||||||
|
pub index_error: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub value_error: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub assertion_error: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub runtime_error: GepGetter<IntLens<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ErrorIdsLens;
|
||||||
|
|
||||||
|
impl<'ctx> StructureOptic<'ctx> for ErrorIdsLens {
|
||||||
|
type Fields = ErrorIdsFields<'ctx>;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"ErrorIds"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||||
|
let i32_lens = IntLens(builder.ctx.i32_type());
|
||||||
|
ErrorIdsFields {
|
||||||
|
index_error: builder.add_field("index_error", i32_lens),
|
||||||
|
value_error: builder.add_field("value_error", i32_lens),
|
||||||
|
assertion_error: builder.add_field("assertion_error", i32_lens),
|
||||||
|
runtime_error: builder.add_field("runtime_error", i32_lens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ErrorContextFields<'ctx> {
|
||||||
|
pub error_ids: GepGetter<AddressLens<ErrorIdsLens>>,
|
||||||
|
pub error_id: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub message_template: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
|
pub param1: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub param2: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub param3: GepGetter<IntLens<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ErrorContextLens;
|
||||||
|
|
||||||
|
impl<'ctx> StructureOptic<'ctx> for ErrorContextLens {
|
||||||
|
type Fields = ErrorContextFields<'ctx>;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"ErrorContext"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||||
|
ErrorContextFields {
|
||||||
|
error_ids: builder.add_field("error_ids", AddressLens(ErrorIdsLens)),
|
||||||
|
error_id: builder.add_field("error_id", IntLens(builder.ctx.i32_type())),
|
||||||
|
message_template: builder
|
||||||
|
.add_field("message_template", AddressLens(IntLens(builder.ctx.i8_type()))),
|
||||||
|
param1: builder.add_field("param1", IntLens(builder.ctx.i64_type())),
|
||||||
|
param2: builder.add_field("param2", IntLens(builder.ctx.i64_type())),
|
||||||
|
param3: builder.add_field("param3", IntLens(builder.ctx.i64_type())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_error_ids<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, ErrorIdsLens> {
|
||||||
|
// ErrorIdsLens.get_fields(ctx.ctx).assertion_error.
|
||||||
|
let error_ids = ErrorIdsLens.alloca(ctx, "error_ids");
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let get_string_id =
|
||||||
|
|string_id| llvm_i32.const_int(ctx.resolver.get_string_id(string_id) as u64, false);
|
||||||
|
|
||||||
|
error_ids.focus(ctx, |fields| &fields.index_error).store(ctx, &get_string_id("0:IndexError"));
|
||||||
|
error_ids.focus(ctx, |fields| &fields.value_error).store(ctx, &get_string_id("0:ValueError"));
|
||||||
|
error_ids
|
||||||
|
.focus(ctx, |fields| &fields.assertion_error)
|
||||||
|
.store(ctx, &get_string_id("0:AssertionError"));
|
||||||
|
error_ids
|
||||||
|
.focus(ctx, |fields| &fields.runtime_error)
|
||||||
|
.store(ctx, &get_string_id("0:RuntimeError"));
|
||||||
|
|
||||||
|
error_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_initialize<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
errctx: &Address<'ctx, ErrorContextLens>,
|
||||||
|
error_ids: &Address<'ctx, ErrorIdsLens>,
|
||||||
|
) {
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_context_initialize")
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
|
||||||
|
.arg("error_ids", &AddressLens(ErrorIdsLens), error_ids)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_has_no_error<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
errctx: &Address<'ctx, ErrorContextLens>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error")
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
|
||||||
|
.returning("has_error", &IntLens(ctx.ctx.bool_type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_get_error_str<'ctx>(
|
||||||
|
size_type: IntType<'ctx>,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
errctx: &Address<'ctx, ErrorContextLens>,
|
||||||
|
dst_str: &Address<'ctx, StrLens<'ctx>>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(size_type, "__nac3_error_context_get_error_str"),
|
||||||
|
)
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
|
||||||
|
.arg("dst_str", &AddressLens(StrLens { size_type }), dst_str)
|
||||||
|
.returning("has_error", &IntLens(ctx.ctx.bool_type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prepare_error_context<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Address<'ctx, ErrorContextLens> {
|
||||||
|
let error_ids = build_error_ids(ctx);
|
||||||
|
let errctx_ptr = ErrorContextLens.alloca(ctx, "errctx");
|
||||||
|
call_nac3_error_context_initialize(ctx, &errctx_ptr, &error_ids);
|
||||||
|
errctx_ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
errctx_ptr: &Address<'ctx, ErrorContextLens>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let has_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr);
|
||||||
|
let error_str_ptr = StrLens { size_type }.alloca(ctx, "error_str");
|
||||||
|
call_nac3_error_context_get_error_str(size_type, ctx, errctx_ptr, &error_str_ptr);
|
||||||
|
|
||||||
|
let error_id = errctx_ptr.focus(ctx, |fields| &fields.error_id).load(ctx, "error_id");
|
||||||
|
let error_str = error_str_ptr.load(ctx, "error_str");
|
||||||
|
let param1 = errctx_ptr.focus(ctx, |fields| &fields.param1).load(ctx, "param1");
|
||||||
|
let param2 = errctx_ptr.focus(ctx, |fields| &fields.param2).load(ctx, "param2");
|
||||||
|
let param3 = errctx_ptr.focus(ctx, |fields| &fields.param3).load(ctx, "param3");
|
||||||
|
ctx.make_assert_impl_by_id(
|
||||||
|
generator,
|
||||||
|
has_error,
|
||||||
|
error_id,
|
||||||
|
error_str.get_llvm_value(),
|
||||||
|
[Some(param1), Some(param2), Some(param3)],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_dummy_raise<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext,
|
||||||
|
) {
|
||||||
|
let errctx = prepare_error_context(ctx);
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise")
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), &errctx)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, &errctx);
|
||||||
|
}
|
|
@ -1,11 +1,9 @@
|
||||||
use crate::{typecheck::typedef::Type, util::SizeVariant};
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
mod test;
|
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayType,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||||
NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
@ -16,13 +14,18 @@ use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::Module,
|
module::Module,
|
||||||
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
|
types::{BasicTypeEnum, IntType},
|
||||||
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
|
pub mod error_context;
|
||||||
|
pub mod numpy;
|
||||||
|
pub mod test;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt(ctx: &Context) -> Module {
|
pub fn load_irrt(ctx: &Context) -> Module {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
|
@ -929,63 +932,3 @@ pub fn call_ndarray_calc_broadcast_index<
|
||||||
Box::new(|_, v| v.into()),
|
Box::new(|_, v| v.into()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
|
|
||||||
match ty.get_bit_width() {
|
|
||||||
32 => SizeVariant::Bits32,
|
|
||||||
64 => SizeVariant::Bits64,
|
|
||||||
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
size_type: IntType<'ctx>,
|
|
||||||
base_name: &str,
|
|
||||||
build_func_type: BuildFuncTypeFn,
|
|
||||||
) -> FunctionValue<'ctx>
|
|
||||||
where
|
|
||||||
BuildFuncTypeFn: Fn() -> FunctionType<'ctx>,
|
|
||||||
{
|
|
||||||
let mut fn_name = base_name.to_owned();
|
|
||||||
match get_size_variant(size_type) {
|
|
||||||
SizeVariant::Bits32 => {
|
|
||||||
// The original fn_name is the correct function name
|
|
||||||
}
|
|
||||||
SizeVariant::Bits64 => {
|
|
||||||
// Append "64" at the end, this is the naming convention for 64-bit
|
|
||||||
fn_name.push_str("64");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get (or declare then get if does not exist) the corresponding function
|
|
||||||
ctx.module.get_function(&fn_name).unwrap_or_else(|| {
|
|
||||||
let fn_type = build_func_type();
|
|
||||||
ctx.module.add_function(&fn_name, fn_type, None)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> {
|
|
||||||
let i8_type = ctx.i8_type();
|
|
||||||
|
|
||||||
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
|
|
||||||
let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx);
|
|
||||||
struct_ty.ptr_type(AddressSpace::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_size<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NpArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let size_type = ndarray.ty.size_type;
|
|
||||||
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
|
|
||||||
size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(function, &[ndarray.ptr.into()], "size")
|
|
||||||
.unwrap()
|
|
||||||
.try_as_basic_value()
|
|
||||||
.unwrap_left()
|
|
||||||
.into_int_value()
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,415 @@
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use inkwell::{
|
||||||
|
types::{BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::optics::*;
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
classes::{ListValue, UntypedArrayLikeAccessor},
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
error_context::{check_error_context, prepare_error_context, ErrorContextLens},
|
||||||
|
util::{get_sized_dependent_function_name, FunctionBuilder},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct NpArrayFields<'ctx> {
|
||||||
|
pub data: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
|
pub itemsize: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub ndims: GepGetter<IntLens<'ctx>>,
|
||||||
|
pub shape: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
|
pub strides: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct NpArrayLens<'ctx> {
|
||||||
|
pub size_type: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> {
|
||||||
|
type Fields = NpArrayFields<'ctx>;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"NDArray"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||||
|
NpArrayFields {
|
||||||
|
data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))),
|
||||||
|
itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())),
|
||||||
|
ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())),
|
||||||
|
shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))),
|
||||||
|
strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other convenient utilities for NpArray
|
||||||
|
impl<'ctx> Address<'ctx, NpArrayLens<'ctx>> {
|
||||||
|
pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> {
|
||||||
|
let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims");
|
||||||
|
let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape");
|
||||||
|
ArraySlice { num_elements: ndims, base: shape_base_ptr }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn strides_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> {
|
||||||
|
let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims");
|
||||||
|
let strides_base_ptr = self.focus(ctx, |fields| &fields.strides).load(ctx, "strides");
|
||||||
|
ArraySlice { num_elements: ndims, base: strides_base_ptr }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProducerWriteToArray<'ctx, G, ElementOptic> = Box<
|
||||||
|
dyn Fn(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, '_>,
|
||||||
|
&ArraySlice<'ctx, ElementOptic>,
|
||||||
|
) -> Result<(), String>
|
||||||
|
+ 'ctx,
|
||||||
|
>;
|
||||||
|
|
||||||
|
struct Producer<'ctx, G: CodeGenerator + ?Sized, ElementOptic> {
|
||||||
|
pub count: IntValue<'ctx>,
|
||||||
|
pub write_to_array: ProducerWriteToArray<'ctx, G, ElementOptic>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO: UPDATE DOCUMENTATION
|
||||||
|
/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
|
///
|
||||||
|
/// ### Notes on `shape`
|
||||||
|
///
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
||||||
|
/// learn how `shape` gets from being a Python user expression to here.
|
||||||
|
pub fn parse_input_shape_arg<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
) -> Producer<'ctx, G, IntLens<'ctx>>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(shape_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
// A list has to be a PointerValue
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None);
|
||||||
|
|
||||||
|
// Create `Producer`
|
||||||
|
let ndims = shape_list.load_size(ctx, Some("count"));
|
||||||
|
Producer {
|
||||||
|
count: ndims,
|
||||||
|
write_to_array: Box::new(move |ctx, generator, dst_array| {
|
||||||
|
// Basically iterate through the list and write to `dst_slice` accordingly
|
||||||
|
let init_val = size_type.const_zero();
|
||||||
|
let max_val = (ndims, false);
|
||||||
|
let incr_val = size_type.const_int(1, false);
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
init_val,
|
||||||
|
max_val,
|
||||||
|
|generator, ctx, _hooks, axis| {
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim =
|
||||||
|
shape_list.data().get(ctx, generator, &axis, None).into_int_value();
|
||||||
|
|
||||||
|
// Cast `dim` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array.ix(ctx, axis, "dim").store(ctx, &dim);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
incr_val,
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty: tuple_types } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
|
let ndims = tuple_types.len();
|
||||||
|
|
||||||
|
// A tuple has to be a StructValue
|
||||||
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||||
|
let shape_tuple = shape.into_struct_value();
|
||||||
|
|
||||||
|
Producer {
|
||||||
|
count: size_type.const_int(ndims as u64, false),
|
||||||
|
write_to_array: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
for axis in 0..ndims {
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(
|
||||||
|
shape_tuple,
|
||||||
|
axis as u32,
|
||||||
|
format!("dim{axis}").as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
|
||||||
|
// Cast `dim` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array
|
||||||
|
.ix(ctx, size_type.const_int(axis as u64, false), "dim")
|
||||||
|
.store(ctx, &dim);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
// The value has to be an integer
|
||||||
|
let shape_int = shape.into_int_value();
|
||||||
|
|
||||||
|
Producer {
|
||||||
|
count: size_type.const_int(1, false),
|
||||||
|
write_to_array: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
// Cast `shape_int` to SizeT
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Write
|
||||||
|
dst_array
|
||||||
|
.ix(ctx, size_type.const_zero() /* Only index 0 is set */, "dim")
|
||||||
|
.store(ctx, &dim);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("parse_input_shape_arg encountered unknown type"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloca_ndarray<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_type: BasicTypeEnum<'ctx>,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Address<'ctx, NpArrayLens<'ctx>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
// Allocate ndarray
|
||||||
|
let ndarray_ptr = NpArrayLens { size_type }.alloca(ctx, name);
|
||||||
|
|
||||||
|
// Set ndims
|
||||||
|
ndarray_ptr.focus(ctx, |fields| &fields.ndims).store(ctx, &ndims);
|
||||||
|
|
||||||
|
// Set itemsize
|
||||||
|
let itemsize = elem_type.size_of().unwrap();
|
||||||
|
let itemsize =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(itemsize, size_type, "itemsize").unwrap();
|
||||||
|
ndarray_ptr.focus(ctx, |fields| &fields.itemsize).store(ctx, &itemsize);
|
||||||
|
|
||||||
|
// Allocate and set shape
|
||||||
|
let shape_ptr = ctx.builder.build_array_alloca(size_type, ndims, "shape").unwrap();
|
||||||
|
ndarray_ptr
|
||||||
|
.focus(ctx, |fields| &fields.shape)
|
||||||
|
.store(ctx, &Address { addressee_optic: IntLens(size_type), address: shape_ptr });
|
||||||
|
|
||||||
|
// Allocate and set strides
|
||||||
|
let strides_ptr = ctx.builder.build_array_alloca(size_type, ndims, "strides").unwrap();
|
||||||
|
ndarray_ptr
|
||||||
|
.focus(ctx, |fields| &fields.strides)
|
||||||
|
.store(ctx, &Address { addressee_optic: IntLens(size_type), address: strides_ptr });
|
||||||
|
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum NDArrayInitMode<'ctx, G: CodeGenerator + ?Sized> {
|
||||||
|
NDim { ndim: IntValue<'ctx> },
|
||||||
|
Shape { shape: Producer<'ctx, G, IntLens<'ctx>> },
|
||||||
|
ShapeAndAllocaData { shape: Producer<'ctx, G, IntLens<'ctx>> },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO: DOCUMENT ME
|
||||||
|
pub fn alloca_ndarray_and_init<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_type: BasicTypeEnum<'ctx>,
|
||||||
|
init_mode: NDArrayInitMode<'ctx, G>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Address<'ctx, NpArrayLens<'ctx>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
// It is implemented verbosely in order to make the initialization modes super clear in their intent.
|
||||||
|
match init_mode {
|
||||||
|
NDArrayInitMode::NDim { ndim } => {
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
NDArrayInitMode::Shape { shape } => {
|
||||||
|
let ndims = shape.count;
|
||||||
|
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
|
|
||||||
|
// Fill `ndarray.shape`
|
||||||
|
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_array(ctx))?;
|
||||||
|
|
||||||
|
// Check if `shape` has bad inputs
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndims,
|
||||||
|
&ndarray_ptr.focus(ctx, |fields| &fields.shape).load(ctx, "shape"),
|
||||||
|
);
|
||||||
|
|
||||||
|
// NOTE: DO NOT DO `set_strides_by_shape` HERE.
|
||||||
|
// Simply this is because we specified that `SetShape` wouldn't do `set_strides_by_shape`
|
||||||
|
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
NDArrayInitMode::ShapeAndAllocaData { shape } => {
|
||||||
|
let ndims = shape.count;
|
||||||
|
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
|
|
||||||
|
// Fill `ndarray.shape`
|
||||||
|
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_array(ctx))?;
|
||||||
|
|
||||||
|
// Check if `shape` has bad inputs
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndims,
|
||||||
|
&ndarray_ptr.focus(ctx, |fields| &fields.shape).load(ctx, "shape"),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Now we populate `ndarray.data` by alloca-ing.
|
||||||
|
// But first, we need to know the size of the ndarray to know how many elements to alloca,
|
||||||
|
// since calculating nbytes of an ndarray requires `ndarray.shape` to be set.
|
||||||
|
let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray_ptr);
|
||||||
|
|
||||||
|
// Alloca `data` and assign it to `ndarray.data`
|
||||||
|
let data_ptr =
|
||||||
|
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes, "data").unwrap();
|
||||||
|
ndarray_ptr.focus(ctx, |fields| &fields.data).store(
|
||||||
|
ctx,
|
||||||
|
&Address { addressee_optic: IntLens::int8(ctx.ctx), address: data_ptr },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Finally, do `set_strides_by_shape`
|
||||||
|
// Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are.
|
||||||
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray_ptr);
|
||||||
|
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
shape_ptr: &Address<'ctx, IntLens<'ctx>>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let errctx = prepare_error_context(ctx);
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("errctx", &AddressLens(ErrorContextLens), &errctx)
|
||||||
|
.arg("ndims", &IntLens(size_type), &ndims)
|
||||||
|
.arg("shape", &AddressLens(IntLens(size_type)), shape_ptr)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, &errctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr)
|
||||||
|
.returning("nbytes", &IntLens(size_type))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>,
|
||||||
|
fill_value_ptr: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(size_type, "__nac3_ndarray_fill_generic"),
|
||||||
|
)
|
||||||
|
.arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr)
|
||||||
|
.arg("pvalue", &OpaqueAddressLens, fill_value_ptr)
|
||||||
|
.returning_void();
|
||||||
|
}
|
|
@ -11,7 +11,6 @@ mod tests {
|
||||||
|
|
||||||
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
|
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
|
||||||
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
|
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
|
||||||
|
|
||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
eprintln!("irrt_test failed with status {}:", output.status);
|
eprintln!("irrt_test failed with status {}:", output.status);
|
||||||
eprintln!("====== stdout ======");
|
eprintln!("====== stdout ======");
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::{BasicMetadataTypeEnum, BasicType, IntType},
|
||||||
|
values::{AnyValue, BasicMetadataValueEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::optics::*;
|
||||||
|
use crate::{codegen::CodeGenContext, util::SizeVariant};
|
||||||
|
|
||||||
|
fn get_size_variant(ty: IntType) -> SizeVariant {
|
||||||
|
match ty.get_bit_width() {
|
||||||
|
32 => SizeVariant::Bits32,
|
||||||
|
64 => SizeVariant::Bits64,
|
||||||
|
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String {
|
||||||
|
let mut fn_name = fn_name.to_owned();
|
||||||
|
match get_size_variant(ty) {
|
||||||
|
SizeVariant::Bits32 => {
|
||||||
|
// Do nothing, `fn_name` already has the correct name
|
||||||
|
}
|
||||||
|
SizeVariant::Bits64 => {
|
||||||
|
// Append "64", this is the naming convention
|
||||||
|
fn_name.push_str("64");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn_name
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Variadic argument?
|
||||||
|
pub struct FunctionBuilder<'ctx, 'a> {
|
||||||
|
ctx: &'a CodeGenContext<'ctx, 'a>,
|
||||||
|
fn_name: &'a str,
|
||||||
|
arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||||
|
pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self {
|
||||||
|
FunctionBuilder { ctx, fn_name, arguments: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name is for self-documentation
|
||||||
|
#[must_use]
|
||||||
|
pub fn arg<S: MemoryOptic<'ctx>>(mut self, _name: &'static str, optic: &S, arg: &S::MemoryValue) -> Self {
|
||||||
|
self.arguments
|
||||||
|
.push((optic.get_llvm_type(self.ctx.ctx).into(), arg.get_llvm_value().into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn returning<S: Prism<'ctx>>(self, name: &'static str, return_prism: &S) -> S::MemoryValue {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = return_prism.get_llvm_type(self.ctx.ctx);
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
||||||
|
return_prism.review(ret.as_any_value_enum())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
||||||
|
pub fn returning_void(self) {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = self.ctx.ctx.void_type();
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
self.ctx.builder.build_call(function, ¶m_vals, "").unwrap();
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,6 +35,54 @@ fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`llvm.lifetime.start`](https://releases.llvm.org/14.0.0/docs/LangRef.html#llvm-lifetime-start-intrinsic)
|
||||||
|
/// intrinsic.
|
||||||
|
pub fn call_lifetime_start<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
const FN_NAME: &str = "llvm.lifetime.start";
|
||||||
|
// NOTE: inkwell temporary workaround, see [`call_stackrestore`] for details
|
||||||
|
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
let llvm_p0i8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let fn_type = llvm_void.fn_type(&[llvm_i64.into(), llvm_p0i8.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function(FN_NAME, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[size.into(), ptr.into()], "")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`llvm.lifetime.end`](https://releases.llvm.org/14.0.0/docs/LangRef.html#llvm-lifetime-end-intrinsic)
|
||||||
|
/// intrinsic.
|
||||||
|
pub fn call_lifetime_end<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
const FN_NAME: &str = "llvm.lifetime.end";
|
||||||
|
// NOTE: inkwell temporary workaround, see [`call_stackrestore`] for details
|
||||||
|
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
let llvm_p0i8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let fn_type = llvm_void.fn_type(&[llvm_i64.into(), llvm_p0i8.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function(FN_NAME, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[size.into(), ptr.into()], "")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
|
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
|
||||||
/// intrinsic.
|
/// intrinsic.
|
||||||
pub fn call_stacksave<'ctx>(
|
pub fn call_stacksave<'ctx>(
|
||||||
|
|
|
@ -23,8 +23,10 @@ use inkwell::{
|
||||||
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use irrt::error_context::StrLens;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::{Location, Stmt, StrRef};
|
use nac3parser::ast::{Location, Stmt, StrRef};
|
||||||
|
use optics::MemoryOptic as _;
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
|
@ -42,6 +44,8 @@ mod generator;
|
||||||
pub mod irrt;
|
pub mod irrt;
|
||||||
pub mod llvm_intrinsics;
|
pub mod llvm_intrinsics;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
|
pub mod numpy_new;
|
||||||
|
pub mod optics;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -646,6 +650,8 @@ pub fn gen_func_impl<
|
||||||
..primitives
|
..primitives
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let llvm_str_ty =
|
||||||
|
StrLens { size_type: generator.get_size_type(context) }.get_llvm_type(context);
|
||||||
let mut type_cache: HashMap<_, _> = [
|
let mut type_cache: HashMap<_, _> = [
|
||||||
(primitives.int32, context.i32_type().into()),
|
(primitives.int32, context.i32_type().into()),
|
||||||
(primitives.int64, context.i64_type().into()),
|
(primitives.int64, context.i64_type().into()),
|
||||||
|
@ -653,21 +659,7 @@ pub fn gen_func_impl<
|
||||||
(primitives.uint64, context.i64_type().into()),
|
(primitives.uint64, context.i64_type().into()),
|
||||||
(primitives.float, context.f64_type().into()),
|
(primitives.float, context.f64_type().into()),
|
||||||
(primitives.bool, context.i8_type().into()),
|
(primitives.bool, context.i8_type().into()),
|
||||||
(primitives.str, {
|
(primitives.str, llvm_str_ty),
|
||||||
let name = "str";
|
|
||||||
match module.get_struct_type(name) {
|
|
||||||
None => {
|
|
||||||
let str_type = context.opaque_struct_type("str");
|
|
||||||
let fields = [
|
|
||||||
context.i8_type().ptr_type(AddressSpace::default()).into(),
|
|
||||||
generator.get_size_type(context).into(),
|
|
||||||
];
|
|
||||||
str_type.set_body(&fields, false);
|
|
||||||
str_type.into()
|
|
||||||
}
|
|
||||||
Some(t) => t.as_basic_type_enum(),
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
(primitives.range, RangeType::new(context).as_base_type().into()),
|
(primitives.range, RangeType::new(context).as_base_type().into()),
|
||||||
(primitives.exception, {
|
(primitives.exception, {
|
||||||
let name = "Exception";
|
let name = "Exception";
|
||||||
|
@ -677,7 +669,7 @@ pub fn gen_func_impl<
|
||||||
let exception = context.opaque_struct_type("Exception");
|
let exception = context.opaque_struct_type("Exception");
|
||||||
let int32 = context.i32_type().into();
|
let int32 = context.i32_type().into();
|
||||||
let int64 = context.i64_type().into();
|
let int64 = context.i64_type().into();
|
||||||
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
|
let str_ty = llvm_str_ty;
|
||||||
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
|
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
|
||||||
exception.set_body(&fields, false);
|
exception.set_body(&fields, false);
|
||||||
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
|
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,91 @@
|
||||||
|
use inkwell::values::{BasicValueEnum, PointerValue};
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::optics::build_opaque_alloca, symbol_resolver::ValueEnum, toplevel::DefinitionId, typecheck::typedef::{FunSignature, Type}
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
irrt::{
|
||||||
|
self,
|
||||||
|
numpy::{alloca_ndarray_and_init, parse_input_shape_arg, NDArrayInitMode, NpArrayLens},
|
||||||
|
},
|
||||||
|
optics::Address,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`.
|
||||||
|
fn call_ndarray_empty_impl<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Address<'ctx, NpArrayLens<'ctx>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let elem_type = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let shape = parse_input_shape_arg(generator, ctx, shape, shape_ty);
|
||||||
|
let ndarray_ptr = alloca_ndarray_and_init(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_type,
|
||||||
|
NDArrayInitMode::ShapeAndAllocaData { shape },
|
||||||
|
name,
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_ndarray_fill_impl<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Address<'ctx, NpArrayLens<'ctx>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let ndarray_ptr = call_ndarray_empty_impl(generator, ctx, elem_ty, shape, shape_ty, name)?;
|
||||||
|
|
||||||
|
// NOTE: fill_value's type is not checked!!
|
||||||
|
let fill_value_ptr = build_opaque_alloca(ctx, fill_value.get_type(), name);
|
||||||
|
fill_value_ptr.store(ctx, );
|
||||||
|
// let fill_value_ptr = ctx.builder.build_alloca(, "fill_value_ptr").unwrap();
|
||||||
|
// ctx.builder.build_store(fill_value_ptr, fill_value);
|
||||||
|
|
||||||
|
// let ok = irrt::numpy::call_nac3_ndarray_fill_generic(generator, ctx, ndarray_ptr, Address { fill_value_ptr } );
|
||||||
|
todo!()
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.empty`.
|
||||||
|
pub fn gen_ndarray_empty<'ctx, G>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray_ptr = call_ndarray_empty_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
"empty_ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.address)
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::{AnyValue, BasicValue, BasicValueEnum, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::core::{MemoryGetter, MemoryOptic, MemorySetter, OpticValue, Prism};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Address<'ctx, AddresseeOptic> {
|
||||||
|
pub addressee_optic: AddresseeOptic,
|
||||||
|
pub address: PointerValue<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> {
|
||||||
|
pub fn cast_to<S: MemoryOptic<'ctx>>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
new_optic: S,
|
||||||
|
) -> Address<'ctx, S> {
|
||||||
|
let to_ptr_type = new_optic.get_llvm_type(ctx.ctx).ptr_type(AddressSpace::default());
|
||||||
|
let casted_address =
|
||||||
|
ctx.builder.build_pointer_cast(self.address, to_ptr_type, "ptr_casted").unwrap();
|
||||||
|
Address { addressee_optic: new_optic, address: casted_address }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cast_to_opaque(&self) -> OpaqueAddress<'ctx> {
|
||||||
|
OpaqueAddress(self.address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddresseeOptic> OpticValue<'ctx> for Address<'ctx, AddresseeOptic> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.address.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AddressLens<AddresseeOptic>(pub AddresseeOptic);
|
||||||
|
|
||||||
|
impl<'ctx, AddresseeOptic: MemoryOptic<'ctx>> MemoryOptic<'ctx> for AddressLens<AddresseeOptic> {
|
||||||
|
type MemoryValue = Address<'ctx, AddresseeOptic>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddresseeOptic: MemoryOptic<'ctx>> Prism<'ctx> for AddressLens<AddresseeOptic> {
|
||||||
|
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::MemoryValue {
|
||||||
|
Address {
|
||||||
|
addressee_optic: self.0.clone(),
|
||||||
|
address: value.as_any_value_enum().into_pointer_value(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddressesOptic: MemoryOptic<'ctx>> MemoryGetter<'ctx> for AddressLens<AddressesOptic> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Self::MemoryValue {
|
||||||
|
self.review(ctx.builder.build_load(pointer, name).unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddressesOptic: MemoryOptic<'ctx>> MemorySetter<'ctx> for AddressLens<AddressesOptic> {
|
||||||
|
fn set(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
value: &Self::MemoryValue,
|
||||||
|
) {
|
||||||
|
ctx.builder.build_store(pointer, value.address).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To make [`Address`] convenient to use
|
||||||
|
impl<'ctx, AddresseeOptic: MemoryGetter<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||||
|
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> AddresseeOptic::MemoryValue {
|
||||||
|
self.addressee_optic.get(ctx, self.address, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To make [`Address`] convenient to use
|
||||||
|
impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||||
|
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: &AddresseeOptic::MemoryValue) {
|
||||||
|
self.addressee_optic.set(ctx, self.address, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct OpaqueAddress<'ctx>(pub PointerValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> OpaqueAddress<'ctx> {
|
||||||
|
pub fn cast_to<AddresseeOptic: MemoryOptic<'ctx>>(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx CodeGenContext,
|
||||||
|
addressee_optic: AddresseeOptic,
|
||||||
|
name: &str,
|
||||||
|
) -> Address<'ctx, AddresseeOptic> {
|
||||||
|
let ptr = ctx.builder.build_pointer_cast(
|
||||||
|
self.0,
|
||||||
|
addressee_optic.get_llvm_type(ctx.ctx).ptr_type(AddressSpace::default()),
|
||||||
|
name,
|
||||||
|
);
|
||||||
|
Address { addressee_optic, address: ptr }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct OpaqueAddressLens;
|
||||||
|
|
||||||
|
impl<'ctx> MemoryOptic<'ctx> for OpaqueAddressLens {
|
||||||
|
type MemoryValue = BasicValueEnum<'ctx>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> OpaqueAddress<'ctx> {
|
||||||
|
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) {
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::BasicTypeEnum,
|
||||||
|
values::{AnyValue, BasicValue, BasicValueEnum, PointerValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::address::Address;
|
||||||
|
|
||||||
|
// TODO: Write a taxonomy
|
||||||
|
|
||||||
|
pub trait OpticValue<'ctx> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T: BasicValue<'ctx>> OpticValue<'ctx> for T {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: The interface is unintuitive
|
||||||
|
pub trait MemoryOptic<'ctx>: Clone {
|
||||||
|
type MemoryValue: OpticValue<'ctx>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||||
|
|
||||||
|
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Address<'ctx, Self> {
|
||||||
|
let ptr = ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap();
|
||||||
|
Address { addressee_optic: self.clone(), address: ptr }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Prism<'ctx>: MemoryOptic<'ctx> {
|
||||||
|
// TODO: Return error if `review` fails
|
||||||
|
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::MemoryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait MemoryGetter<'ctx>: MemoryOptic<'ctx> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Self::MemoryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait MemorySetter<'ctx>: MemoryOptic<'ctx> {
|
||||||
|
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, value: &Self::MemoryValue);
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::PointerValue,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
address::Address,
|
||||||
|
core::{MemoryGetter, MemoryOptic},
|
||||||
|
};
|
||||||
|
|
||||||
|
// ((Memory, Pointer) -> ElementOptic::Value*)
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct GepGetter<ElementOptic> {
|
||||||
|
/// The LLVM GEP index
|
||||||
|
pub gep_index: u64,
|
||||||
|
/// Element (or field in the context of `struct`s) name. Used for cosmetics.
|
||||||
|
pub name: &'static str,
|
||||||
|
/// The lens to view the actual value after applying this [`FieldLens<T>`]
|
||||||
|
pub element_optic: ElementOptic,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, ElementOptic: MemoryOptic<'ctx>> MemoryOptic<'ctx> for GepGetter<ElementOptic> {
|
||||||
|
type MemoryValue = Address<'ctx, ElementOptic>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.element_optic.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, ElementOptic: MemoryOptic<'ctx>> MemoryGetter<'ctx> for GepGetter<ElementOptic> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Self::MemoryValue {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
|
||||||
|
let element_ptr = unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
pointer,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index, false)],
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
Address { address: element_ptr, addressee_optic: self.element_optic.clone() }
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,66 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{AnyValue, BasicValue, IntValue, PointerValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::core::{MemoryGetter, MemorySetter, MemoryOptic, Prism};
|
||||||
|
|
||||||
|
// NOTE: I wanted to make Int8Lens, Int16Lens, Int32Lens, with all
|
||||||
|
// having the trait IsIntLens, and implement `impl <S: IsIntLens> Optic<S> for T`,
|
||||||
|
// but that clashes with StructureOptic!!
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct IntLens<'ctx>(pub IntType<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> IntLens<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn int8(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i8_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn int32(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i32_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn int64(ctx: &'ctx Context) -> IntLens<'ctx> {
|
||||||
|
IntLens(ctx.i64_type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> MemoryOptic<'ctx> for IntLens<'ctx> {
|
||||||
|
type MemoryValue = IntValue<'ctx>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0.as_basic_type_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> Prism<'ctx> for IntLens<'ctx> {
|
||||||
|
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::MemoryValue {
|
||||||
|
let int = value.as_any_value_enum().into_int_value();
|
||||||
|
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||||
|
int
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> MemoryGetter<'ctx> for IntLens<'ctx> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Self::MemoryValue {
|
||||||
|
self.review(ctx.builder.build_load(pointer, name).unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> MemorySetter<'ctx> for IntLens<'ctx> {
|
||||||
|
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, int: &Self::MemoryValue) {
|
||||||
|
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||||
|
ctx.builder.build_store(pointer, int.as_basic_value_enum()).unwrap();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
use inkwell::values::IntValue;
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
use super::address::Address;
|
||||||
|
|
||||||
|
// Name inspired by https://hackage.haskell.org/package/lens-5.3.2/docs/Control-Lens-At.html#t:Ixed
|
||||||
|
pub trait Ixed<'ctx, ElementOptic> {
|
||||||
|
// TODO: Interface/Method to expose the IntType of index?
|
||||||
|
// or even make index itself parameterized? (probably no)
|
||||||
|
|
||||||
|
fn ix(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Address<'ctx, ElementOptic>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Can do interface seggregation
|
||||||
|
pub trait BoundedIxed<'ctx, ElementOptic>: Ixed<'ctx, ElementOptic> {
|
||||||
|
fn num_elements(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx>;
|
||||||
|
|
||||||
|
// Check if 0 <= index < self.num_elements()
|
||||||
|
fn ix_bounds_checked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Address<'ctx, ElementOptic> {
|
||||||
|
let num_elements = self.num_elements(ctx);
|
||||||
|
let int_type = num_elements.get_type(); // NOTE: Weird get_type(), see comment under `trait Ixed`
|
||||||
|
|
||||||
|
assert_eq!(int_type.get_bit_width(), index.get_type().get_bit_width()); // Might as well check bit width to catch bugs
|
||||||
|
|
||||||
|
// TODO: SGE or UGE? or make it defined by the implementee?
|
||||||
|
|
||||||
|
// Check `0 <= index`
|
||||||
|
let lower_bounded = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
inkwell::IntPredicate::SLE,
|
||||||
|
int_type.const_zero(),
|
||||||
|
index,
|
||||||
|
"lower_bounded",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Check `index < num_elements`
|
||||||
|
let upper_bounded = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(inkwell::IntPredicate::SLT, index, num_elements, "upper_bounded")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Compute `0 <= index && index < num_elements`
|
||||||
|
let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap();
|
||||||
|
|
||||||
|
// Assert `bounded`
|
||||||
|
ctx.make_assert(generator, bounded, "0:IndexError", "nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}", [Some(index), Some(num_elements), None], ctx.current_loc);
|
||||||
|
|
||||||
|
// ...and finally do indexing
|
||||||
|
self.ix(ctx, index, name)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
pub mod address;
|
||||||
|
pub mod core;
|
||||||
|
pub mod gep;
|
||||||
|
pub mod int;
|
||||||
|
pub mod ixed;
|
||||||
|
pub mod slice;
|
||||||
|
pub mod structure;
|
||||||
|
|
||||||
|
pub use address::*;
|
||||||
|
pub use core::*;
|
||||||
|
pub use gep::*;
|
||||||
|
pub use int::*;
|
||||||
|
pub use ixed::*;
|
||||||
|
pub use slice::*;
|
||||||
|
pub use structure::*;
|
|
@ -0,0 +1,36 @@
|
||||||
|
use super::{
|
||||||
|
core::MemoryOptic,
|
||||||
|
ixed::{BoundedIxed, Ixed},
|
||||||
|
};
|
||||||
|
|
||||||
|
use inkwell::values::IntValue;
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::address::Address;
|
||||||
|
|
||||||
|
pub struct ArraySlice<'ctx, ElementOptic> {
|
||||||
|
pub num_elements: IntValue<'ctx>,
|
||||||
|
pub base: Address<'ctx, ElementOptic>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, ElementOptic: MemoryOptic<'ctx>> Ixed<'ctx, ElementOptic> for ArraySlice<'ctx, ElementOptic> {
|
||||||
|
fn ix(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Address<'ctx, ElementOptic> {
|
||||||
|
let element_addr =
|
||||||
|
unsafe { ctx.builder.build_in_bounds_gep(self.base.address, &[index], name).unwrap() };
|
||||||
|
Address { address: element_addr, addressee_optic: self.base.addressee_optic.clone() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, ElementOptic: MemoryOptic<'ctx>> BoundedIxed<'ctx, ElementOptic>
|
||||||
|
for ArraySlice<'ctx, ElementOptic>
|
||||||
|
{
|
||||||
|
fn num_elements(&self, _ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
self.num_elements
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,139 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::{BasicValue, BasicValueEnum, PointerValue, StructValue},
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
address::Address,
|
||||||
|
core::{MemoryGetter, MemorySetter, MemoryOptic, OpticValue},
|
||||||
|
gep::GepGetter,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub trait StructureOptic<'ctx>: Clone {
|
||||||
|
// Fields of optics
|
||||||
|
type Fields;
|
||||||
|
|
||||||
|
// TODO: Make it an associated function instead?
|
||||||
|
fn struct_name(&self) -> &'static str;
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields;
|
||||||
|
|
||||||
|
fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields {
|
||||||
|
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||||
|
self.build_fields(&mut builder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpticalStructValue<'ctx, StructOptic> {
|
||||||
|
optic: StructOptic,
|
||||||
|
llvm: StructValue<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, StructOptic> OpticValue<'ctx> for OpticalStructValue<'ctx, StructOptic> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.llvm.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check StructType
|
||||||
|
impl<'ctx, T: StructureOptic<'ctx>> MemoryOptic<'ctx> for T {
|
||||||
|
type MemoryValue = OpticalStructValue<'ctx, Self>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||||
|
self.build_fields(&mut builder); // Self::Fields is discarded
|
||||||
|
|
||||||
|
let field_types =
|
||||||
|
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
|
||||||
|
ctx.struct_type(&field_types, false).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T: StructureOptic<'ctx>> MemoryGetter<'ctx> for T {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Self::MemoryValue {
|
||||||
|
OpticalStructValue {
|
||||||
|
optic: self.clone(),
|
||||||
|
llvm: ctx.builder.build_load(pointer, name).unwrap().into_struct_value(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T: StructureOptic<'ctx>> MemorySetter<'ctx> for T {
|
||||||
|
fn set(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pointer: PointerValue<'ctx>,
|
||||||
|
value: &Self::MemoryValue,
|
||||||
|
) {
|
||||||
|
ctx.builder.build_store(pointer, value.llvm).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, AddresseeOptic: StructureOptic<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||||
|
pub fn focus<GetFieldGepFn, FieldElementOptic: MemoryOptic<'ctx>>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
get_field_gep_fn: GetFieldGepFn,
|
||||||
|
) -> Address<'ctx, FieldElementOptic>
|
||||||
|
where
|
||||||
|
GetFieldGepFn: FnOnce(&AddresseeOptic::Fields) -> &GepGetter<FieldElementOptic>,
|
||||||
|
{
|
||||||
|
let fields = self.addressee_optic.get_fields(ctx.ctx);
|
||||||
|
let field = get_field_gep_fn(&fields);
|
||||||
|
field.get(ctx, self.address, field.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only used by [`FieldBuilder`]
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct FieldInfo<'ctx> {
|
||||||
|
gep_index: u64,
|
||||||
|
name: &'ctx str,
|
||||||
|
llvm_type: BasicTypeEnum<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct FieldBuilder<'ctx> {
|
||||||
|
pub ctx: &'ctx Context,
|
||||||
|
gep_index_counter: u64,
|
||||||
|
struct_name: &'ctx str,
|
||||||
|
fields: Vec<FieldInfo<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> FieldBuilder<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
|
||||||
|
FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_gep_index(&mut self) -> u64 {
|
||||||
|
let index = self.gep_index_counter;
|
||||||
|
self.gep_index_counter += 1;
|
||||||
|
index
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_field<ElementOptic: MemoryOptic<'ctx>>(
|
||||||
|
&mut self,
|
||||||
|
name: &'static str,
|
||||||
|
element_optic: ElementOptic,
|
||||||
|
) -> GepGetter<ElementOptic> {
|
||||||
|
let gep_index = self.next_gep_index();
|
||||||
|
|
||||||
|
self.fields.push(FieldInfo {
|
||||||
|
gep_index,
|
||||||
|
name,
|
||||||
|
llvm_type: element_optic.get_llvm_type(self.ctx),
|
||||||
|
});
|
||||||
|
|
||||||
|
GepGetter { gep_index, name, element_optic }
|
||||||
|
}
|
||||||
|
}
|
|
@ -23,4 +23,4 @@ pub mod codegen;
|
||||||
pub mod symbol_resolver;
|
pub mod symbol_resolver;
|
||||||
pub mod toplevel;
|
pub mod toplevel;
|
||||||
pub mod typecheck;
|
pub mod typecheck;
|
||||||
pub mod util;
|
pub(crate) mod util;
|
||||||
|
|
|
@ -279,8 +279,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size_variant_to_int_type(variant: SizeVariant, primitives: &PrimitiveStore) -> Type {
|
fn get_size_variant_of_int(size_variant: SizeVariant, primitives: &PrimitiveStore) -> Type {
|
||||||
match variant {
|
match size_variant {
|
||||||
SizeVariant::Bits32 => primitives.int32,
|
SizeVariant::Bits32 => primitives.int32,
|
||||||
SizeVariant::Bits64 => primitives.int64,
|
SizeVariant::Bits64 => primitives.int64,
|
||||||
}
|
}
|
||||||
|
@ -502,7 +502,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
|
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim),
|
PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => {
|
||||||
|
self.build_np_max_min_function(prim)
|
||||||
|
}
|
||||||
|
|
||||||
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
|
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
|
||||||
self.build_np_minimum_maximum_function(prim)
|
self.build_np_minimum_maximum_function(prim)
|
||||||
|
@ -953,9 +955,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_copy(ctx, &obj, fun, &args, generator)
|
||||||
// gen_ndarray_copy(ctx, &obj, fun, &args, generator)
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
// .map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -971,9 +972,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
|
||||||
// gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
|
Ok(None)
|
||||||
// Ok(None)
|
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -1053,7 +1053,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
);
|
);
|
||||||
|
|
||||||
// The size variant of the function determines the size of the returned int.
|
// The size variant of the function determines the size of the returned int.
|
||||||
let int_sized = size_variant_to_int_type(size_variant, self.primitives);
|
let int_sized = get_size_variant_of_int(size_variant, self.primitives);
|
||||||
|
|
||||||
let ndarray_int_sized =
|
let ndarray_int_sized =
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
||||||
|
@ -1078,7 +1078,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
|
||||||
let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
|
let ret_elem_ty = get_size_variant_of_int(size_variant, &ctx.primitives);
|
||||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1119,7 +1119,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
||||||
|
|
||||||
// The size variant of the function determines the type of int returned
|
// The size variant of the function determines the type of int returned
|
||||||
let int_sized = size_variant_to_int_type(size_variant, self.primitives);
|
let int_sized = get_size_variant_of_int(size_variant, self.primitives);
|
||||||
let ndarray_int_sized =
|
let ndarray_int_sized =
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
||||||
|
|
||||||
|
@ -1142,7 +1142,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
|
||||||
let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
|
let ret_elem_ty = get_size_variant_of_int(size_variant, &ctx.primitives);
|
||||||
let func = match kind {
|
let func = match kind {
|
||||||
Kind::Ceil => builtin_fns::call_ceil,
|
Kind::Ceil => builtin_fns::call_ceil,
|
||||||
Kind::Floor => builtin_fns::call_floor,
|
Kind::Floor => builtin_fns::call_floor,
|
||||||
|
@ -1193,14 +1193,13 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
self.ndarray_float,
|
self.ndarray_float,
|
||||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
let func = match prim {
|
||||||
// let func = match prim {
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
||||||
// PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
||||||
// PrimDef::FunNpZeros => gen_ndarray_zeros,
|
PrimDef::FunNpOnes => gen_ndarray_ones,
|
||||||
// PrimDef::FunNpOnes => gen_ndarray_ones,
|
_ => unreachable!(),
|
||||||
// _ => unreachable!(),
|
};
|
||||||
// };
|
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||||
// func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1246,9 +1245,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_array(ctx, &obj, fun, &args, generator)
|
||||||
// gen_ndarray_array(ctx, &obj, fun, &args, generator)
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
// .map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -1266,9 +1264,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
// type variable
|
// type variable
|
||||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||||
// gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
// .map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1300,9 +1297,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_eye(ctx, &obj, fun, &args, generator)
|
||||||
// gen_ndarray_eye(ctx, &obj, fun, &args, generator)
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
// .map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -1315,9 +1311,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
self.ndarray_float_2d,
|
self.ndarray_float_2d,
|
||||||
&[(int32, "n")],
|
&[(int32, "n")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
todo!()
|
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
||||||
// gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
// .map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
|
@ -1554,39 +1549,45 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the functions `np_min()` and `np_max()`.
|
/// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()`
|
||||||
fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
/// Calls `call_numpy_max_min` with the function name
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]);
|
fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax],
|
||||||
|
);
|
||||||
|
|
||||||
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
|
let (var_map, ret_ty) = match prim {
|
||||||
let var_map = self
|
PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => {
|
||||||
.num_or_ndarray_var_map
|
(self.num_or_ndarray_var_map.clone(), self.primitives.int64)
|
||||||
.clone()
|
}
|
||||||
.into_iter()
|
PrimDef::FunNpMax | PrimDef::FunNpMin => {
|
||||||
.chain(once((ret_ty.id, ret_ty.ty)))
|
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
|
||||||
.collect::<IndexMap<_, _>>();
|
let var_map = self
|
||||||
|
.num_or_ndarray_var_map
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.chain(once((ret_ty.id, ret_ty.ty)))
|
||||||
|
.collect::<IndexMap<_, _>>();
|
||||||
|
(var_map, ret_ty.ty)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
&var_map,
|
&var_map,
|
||||||
prim.name(),
|
prim.name(),
|
||||||
ret_ty.ty,
|
ret_ty,
|
||||||
&[(self.float_or_ndarray_ty.ty, "a")],
|
&[(self.num_or_ndarray_ty.ty, "a")],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let a_ty = fun.0.args[0].ty;
|
let a_ty = fun.0.args[0].ty;
|
||||||
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
|
||||||
let func = match prim {
|
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?))
|
||||||
PrimDef::FunNpMin => builtin_fns::call_numpy_min,
|
|
||||||
PrimDef::FunNpMax => builtin_fns::call_numpy_max,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some(func(generator, ctx, (a_ty, a))?))
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the functions `np_minimum()` and `np_maximum()`.
|
/// Build the functions `np_minimum()` and `np_maximum()`.
|
||||||
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
|
||||||
|
|
|
@ -62,9 +62,11 @@ pub enum PrimDef {
|
||||||
FunMin,
|
FunMin,
|
||||||
FunNpMin,
|
FunNpMin,
|
||||||
FunNpMinimum,
|
FunNpMinimum,
|
||||||
|
FunNpArgmin,
|
||||||
FunMax,
|
FunMax,
|
||||||
FunNpMax,
|
FunNpMax,
|
||||||
FunNpMaximum,
|
FunNpMaximum,
|
||||||
|
FunNpArgmax,
|
||||||
FunAbs,
|
FunAbs,
|
||||||
FunNpIsNan,
|
FunNpIsNan,
|
||||||
FunNpIsInf,
|
FunNpIsInf,
|
||||||
|
@ -216,9 +218,11 @@ impl PrimDef {
|
||||||
PrimDef::FunMin => fun("min", None),
|
PrimDef::FunMin => fun("min", None),
|
||||||
PrimDef::FunNpMin => fun("np_min", None),
|
PrimDef::FunNpMin => fun("np_min", None),
|
||||||
PrimDef::FunNpMinimum => fun("np_minimum", None),
|
PrimDef::FunNpMinimum => fun("np_minimum", None),
|
||||||
|
PrimDef::FunNpArgmin => fun("np_argmin", None),
|
||||||
PrimDef::FunMax => fun("max", None),
|
PrimDef::FunMax => fun("max", None),
|
||||||
PrimDef::FunNpMax => fun("np_max", None),
|
PrimDef::FunNpMax => fun("np_max", None),
|
||||||
PrimDef::FunNpMaximum => fun("np_maximum", None),
|
PrimDef::FunNpMaximum => fun("np_maximum", None),
|
||||||
|
PrimDef::FunNpArgmax => fun("np_argmax", None),
|
||||||
PrimDef::FunAbs => fun("abs", None),
|
PrimDef::FunAbs => fun("abs", None),
|
||||||
PrimDef::FunNpIsNan => fun("np_isnan", None),
|
PrimDef::FunNpIsNan => fun("np_isnan", None),
|
||||||
PrimDef::FunNpIsInf => fun("np_isinf", None),
|
PrimDef::FunNpIsInf => fun("np_isinf", None),
|
||||||
|
|
|
@ -398,7 +398,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
if let Some(exc) = exc {
|
if let Some(exc) = exc {
|
||||||
self.virtual_checks.push((
|
self.virtual_checks.push((
|
||||||
exc.custom.unwrap(),
|
match &*self.unifier.get_ty(exc.custom.unwrap()) {
|
||||||
|
TypeEnum::TFunc(sign) => sign.ret,
|
||||||
|
_ => exc.custom.unwrap(),
|
||||||
|
},
|
||||||
self.primitives.exception,
|
self.primitives.exception,
|
||||||
exc.location,
|
exc.location,
|
||||||
));
|
));
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
/// A helper enum used by [`BuiltinBuilder`]
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum SizeVariant {
|
||||||
|
Bits32,
|
||||||
|
Bits64,
|
||||||
|
}
|
|
@ -183,8 +183,10 @@ def patch(module):
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
module.np_min = np.min
|
module.np_min = np.min
|
||||||
module.np_minimum = np.minimum
|
module.np_minimum = np.minimum
|
||||||
|
module.np_argmin = np.argmin
|
||||||
module.np_max = np.max
|
module.np_max = np.max
|
||||||
module.np_maximum = np.maximum
|
module.np_maximum = np.maximum
|
||||||
|
module.np_argmax = np.argmax
|
||||||
module.np_sin = np.sin
|
module.np_sin = np.sin
|
||||||
module.np_cos = np.cos
|
module.np_cos = np.cos
|
||||||
module.np_exp = np.exp
|
module.np_exp = np.exp
|
||||||
|
|
|
@ -867,6 +867,13 @@ def test_ndarray_minimum_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(min_x_zeros)
|
output_ndarray_float_2(min_x_zeros)
|
||||||
output_ndarray_float_2(min_x_ones)
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_argmin():
|
||||||
|
x = np_array([[1., 2.], [3., 4.]])
|
||||||
|
y = np_argmin(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_int64(y)
|
||||||
|
|
||||||
def test_ndarray_max():
|
def test_ndarray_max():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = np_max(x)
|
y = np_max(x)
|
||||||
|
@ -910,6 +917,13 @@ def test_ndarray_maximum_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(max_x_zeros)
|
output_ndarray_float_2(max_x_zeros)
|
||||||
output_ndarray_float_2(max_x_ones)
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_argmax():
|
||||||
|
x = np_array([[1., 2.], [3., 4.]])
|
||||||
|
y = np_argmax(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_int64(y)
|
||||||
|
|
||||||
def test_ndarray_abs():
|
def test_ndarray_abs():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = abs(x)
|
y = abs(x)
|
||||||
|
@ -1524,11 +1538,13 @@ def run() -> int32:
|
||||||
test_ndarray_minimum_broadcast()
|
test_ndarray_minimum_broadcast()
|
||||||
test_ndarray_minimum_broadcast_lhs_scalar()
|
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||||
test_ndarray_minimum_broadcast_rhs_scalar()
|
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmin()
|
||||||
test_ndarray_max()
|
test_ndarray_max()
|
||||||
test_ndarray_maximum()
|
test_ndarray_maximum()
|
||||||
test_ndarray_maximum_broadcast()
|
test_ndarray_maximum_broadcast()
|
||||||
test_ndarray_maximum_broadcast_lhs_scalar()
|
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||||
test_ndarray_maximum_broadcast_rhs_scalar()
|
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmax()
|
||||||
test_ndarray_abs()
|
test_ndarray_abs()
|
||||||
test_ndarray_isnan()
|
test_ndarray_isnan()
|
||||||
test_ndarray_isinf()
|
test_ndarray_isinf()
|
||||||
|
|
|
@ -81,7 +81,6 @@ in rec {
|
||||||
''
|
''
|
||||||
mkdir -p $out/bin
|
mkdir -p $out/bin
|
||||||
ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt.exe
|
ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt.exe
|
||||||
ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt-test.exe
|
|
||||||
ln -s ${llvm-nac3}/bin/llvm-as.exe $out/bin/llvm-as-irrt.exe
|
ln -s ${llvm-nac3}/bin/llvm-as.exe $out/bin/llvm-as-irrt.exe
|
||||||
'';
|
'';
|
||||||
nac3artiq = pkgs.rustPlatform.buildRustPackage {
|
nac3artiq = pkgs.rustPlatform.buildRustPackage {
|
||||||
|
|
Loading…
Reference in New Issue