From 43e9a9539d2ff080aab313e06598a280349a6cf9 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 10 Jul 2024 00:38:01 +0800 Subject: [PATCH] WIP --- compile_irrt.sh | 3 + flake.nix | 5 +- nac3core/build.rs | 2 + nac3core/irrt/irrt.cpp | 6 +- nac3core/irrt/{irrt.hpp => irrt_basic.hpp} | 244 +------- nac3core/irrt/irrt_everything.hpp | 11 + nac3core/irrt/irrt_numpy_ndarray.hpp | 196 +++++++ nac3core/irrt/irrt_test.cpp | 171 +++++- nac3core/irrt/irrt_typedefs.hpp | 12 + nac3core/irrt/irrt_utils.hpp | 11 + nac3core/src/codegen/builtin_fns.rs | 162 +++--- nac3core/src/codegen/classes.rs | 336 ++++++++++- nac3core/src/codegen/irrt/mod.rs | 419 ++------------ nac3core/src/codegen/numpy.rs | 641 ++++++++++++--------- nac3core/src/lib.rs | 1 + nac3core/src/toplevel/builtins.rs | 26 +- nac3core/src/toplevel/mod.rs | 1 + nac3core/src/util.rs | 5 + 18 files changed, 1263 insertions(+), 989 deletions(-) create mode 100755 compile_irrt.sh rename nac3core/irrt/{irrt.hpp => irrt_basic.hpp} (51%) create mode 100644 nac3core/irrt/irrt_everything.hpp create mode 100644 nac3core/irrt/irrt_numpy_ndarray.hpp create mode 100644 nac3core/irrt/irrt_typedefs.hpp create mode 100644 nac3core/irrt/irrt_utils.hpp create mode 100644 nac3core/src/util.rs diff --git a/compile_irrt.sh b/compile_irrt.sh new file mode 100755 index 00000000..59ea7fe0 --- /dev/null +++ b/compile_irrt.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +clang-irrt --target=wasm32 -x c++ -fno-discard-value-names -fno-exceptions -fno-rtti -O0 -emit-llvm -S -Wall -Wextra nac3core/irrt/irrt.cpp +clang -x c++ -fno-discard-value-names -fno-exceptions -fno-rtti -O0 -emit-llvm -S -Wall -Wextra nac3core/irrt/irrt_test.cpp diff --git a/flake.nix b/flake.nix index a6ce5fce..49df0dd0 100644 --- a/flake.nix +++ b/flake.nix @@ -41,7 +41,7 @@ ''; installPhase = '' - PYTHON_SITEPACKAGES=$out/${pkgs.python3Packages.python.sitePackages} +u PYTHON_SITEPACKAGES=$out/${pkgs.python3Packages.python.sitePackages} mkdir -p $PYTHON_SITEPACKAGES cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $PYTHON_SITEPACKAGES/nac3artiq.so @@ -163,7 +163,10 @@ clippy pre-commit rustfmt + rust-analyzer ]; + # https://nixos.wiki/wiki/Rust#Shell.nix_example + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3core/build.rs b/nac3core/build.rs index 3657ad08..123b2277 100644 --- a/nac3core/build.rs +++ b/nac3core/build.rs @@ -31,6 +31,7 @@ fn compile_irrt(irrt_dir: &Path, out_dir: &Path) { "-S", "-Wall", "-Wextra", + "-Werror=return-type", "-I", irrt_dir.to_str().unwrap(), "-o", @@ -100,6 +101,7 @@ fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) { "-O0", "-Wall", "-Wextra", + "-Werror=return-type", "-lm", // for `tgamma()`, `lgamma()` "-o", exe_path.to_str().unwrap(), diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 4cde95b7..d92b497c 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,3 +1,5 @@ -#include "irrt.hpp" +#include "irrt_everything.hpp" -// All the implementations are from `irrt.hpp` +/* + This file will be read by `clang-irrt` to conveniently produce LLVM IR for `nac3core/codegen`. +*/ diff --git a/nac3core/irrt/irrt.hpp b/nac3core/irrt/irrt_basic.hpp similarity index 51% rename from nac3core/irrt/irrt.hpp rename to nac3core/irrt/irrt_basic.hpp index 5327f685..8935cdb8 100644 --- a/nac3core/irrt/irrt.hpp +++ b/nac3core/irrt/irrt_basic.hpp @@ -1,28 +1,19 @@ -#ifndef IRRT_DONT_TYPEDEF_INTS -typedef _BitInt(8) int8_t; -typedef unsigned _BitInt(8) uint8_t; -typedef _BitInt(32) int32_t; -typedef unsigned _BitInt(32) uint32_t; -typedef _BitInt(64) int64_t; -typedef unsigned _BitInt(64) uint64_t; -#endif +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" + +/* + This header contains IRRT implementations + that do not deserved to be categorized (e.g., into numpy, etc.) + + Check out other *.hpp files before including them here!! +*/ -// NDArray indices are always `uint32_t`. -typedef uint32_t NDIndex; // The type of an index or a value describing the length of a range/slice is // always `int32_t`. typedef int32_t SliceIndex; -template -static T max(T a, T b) { - return a > b ? a : b; -} - -template -static T min(T a, T b) { - return a > b ? b : a; -} - // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function template @@ -39,119 +30,6 @@ static T __nac3_int_exp_impl(T base, T exp) { return res; } -template -static SizeT __nac3_ndarray_calc_size_impl( - const SizeT *list_data, - SizeT list_len, - SizeT begin_idx, - SizeT end_idx -) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -static void __nac3_ndarray_calc_nd_indices_impl( - SizeT index, - const SizeT *dims, - SizeT num_dims, - NDIndex *idxs -) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} - -template -static SizeT __nac3_ndarray_flatten_index_impl( - const SizeT *dims, - SizeT num_dims, - const NDIndex *indices, - SizeT num_indices -) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - -template -static void __nac3_ndarray_calc_broadcast_impl( - const SizeT *lhs_dims, - SizeT lhs_ndims, - const SizeT *rhs_dims, - SizeT rhs_ndims, - SizeT *out_dims -) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT *out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -static void __nac3_ndarray_calc_broadcast_idx_impl( - const SizeT *src_dims, - SizeT src_ndims, - const NDIndex *in_idx, - NDIndex *out_idx -) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} - -template -static void __nac3_ndarray_strides_from_shape_impl( - SizeT ndims, - SizeT *shape, - SizeT *dst_strides -) { - SizeT stride_product = 1; - for (SizeT i = 0; i < ndims; i++) { - int dim_i = ndims - i - 1; - dst_strides[dim_i] = stride_product; - stride_product *= shape[dim_i]; - } -} - extern "C" { #define DEF_nac3_int_exp_(T) \ T __nac3_int_exp_##T(T base, T exp) {\ @@ -334,104 +212,4 @@ extern "C" { return j0(x); } - - uint32_t __nac3_ndarray_calc_size( - const uint32_t *list_data, - uint32_t list_len, - uint32_t begin_idx, - uint32_t end_idx - ) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); - } - - uint64_t __nac3_ndarray_calc_size64( - const uint64_t *list_data, - uint64_t list_len, - uint64_t begin_idx, - uint64_t end_idx - ) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); - } - - void __nac3_ndarray_calc_nd_indices( - uint32_t index, - const uint32_t* dims, - uint32_t num_dims, - NDIndex* idxs - ) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); - } - - void __nac3_ndarray_calc_nd_indices64( - uint64_t index, - const uint64_t* dims, - uint64_t num_dims, - NDIndex* idxs - ) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); - } - - uint32_t __nac3_ndarray_flatten_index( - const uint32_t* dims, - uint32_t num_dims, - const NDIndex* indices, - uint32_t num_indices - ) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); - } - - uint64_t __nac3_ndarray_flatten_index64( - const uint64_t* dims, - uint64_t num_dims, - const NDIndex* indices, - uint64_t num_indices - ) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); - } - - void __nac3_ndarray_calc_broadcast( - const uint32_t *lhs_dims, - uint32_t lhs_ndims, - const uint32_t *rhs_dims, - uint32_t rhs_ndims, - uint32_t *out_dims - ) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); - } - - void __nac3_ndarray_calc_broadcast64( - const uint64_t *lhs_dims, - uint64_t lhs_ndims, - const uint64_t *rhs_dims, - uint64_t rhs_ndims, - uint64_t *out_dims - ) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); - } - - void __nac3_ndarray_calc_broadcast_idx( - const uint32_t *src_dims, - uint32_t src_ndims, - const NDIndex *in_idx, - NDIndex *out_idx - ) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); - } - - void __nac3_ndarray_calc_broadcast_idx64( - const uint64_t *src_dims, - uint64_t src_ndims, - const NDIndex *in_idx, - NDIndex *out_idx - ) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); - } - - void __nac3_ndarray_strides_from_shape(uint32_t ndims, uint32_t* shape, uint32_t* dst_strides) { - __nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides); - } - - void __nac3_ndarray_strides_from_shape64(uint64_t ndims, uint64_t* shape, uint64_t* dst_strides) { - __nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides); - } } \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp new file mode 100644 index 00000000..69ca0039 --- /dev/null +++ b/nac3core/irrt/irrt_everything.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "irrt_basic.hpp" +#include "irrt_numpy_ndarray.hpp" + +/* + All IRRT implementations. + + We don't have any pre-compiled objects, so we are writing all implementations in headers and + concatenate them with `#include` into one massive source file that contains all the IRRT stuff. +*/ \ No newline at end of file diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp new file mode 100644 index 00000000..9b1b6836 --- /dev/null +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -0,0 +1,196 @@ +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" + +/* + NDArray-related implementations. +`*/ + +// NDArray indices are always `uint32_t`. +using NDIndex = uint32_t; + +namespace { +namespace ndarray_util { + // Compute the strides of an ndarray given an ndarray `shape` + // and assuming that the ndarray is *fully C-contagious*. + // + // You might want to read up on https://ajcr.net/stride-guide-part-1/. + template + static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndims; i++) { + int dim_i = ndims - i - 1; + dst_strides[dim_i] = stride_product; + stride_product *= shape[dim_i]; + } + } + + // Compute the size/# of elements of an ndarray given its shape + template + static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i]; + return size; + } +} + +template +struct NDArrayIndicesIter { + SizeT ndims; + const SizeT *shape; + SizeT *indices; + + void set_indices_zero() { + __builtin_memset(indices, 0, sizeof(SizeT) * ndims); + } + + void next() { + for (SizeT i = 0; i < ndims; i++) { + SizeT dim_i = ndims - i - 1; + + indices[dim_i]++; + if (indices[dim_i] < shape[dim_i]) { + break; + } else { + indices[dim_i] = 0; + } + } + } +}; + +// The NDArray object. `SizeT` is the *signed* size type of this ndarray. +// +// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT +// +// Some resources you might find helpful: +// - The official numpy implementations: +// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst +// - On strides (about reshaping, slicing, C-contagiousness, etc) +// - https://ajcr.net/stride-guide-part-1/. +// - https://ajcr.net/stride-guide-part-2/. +// - https://ajcr.net/stride-guide-part-3/. +template +struct NDArray { + // The underlying data this `ndarray` is pointing to. + // + // NOTE: Formally this should be of type `void *`, but clang + // translates `void *` to `i8 *` when run with `-S -emit-llvm`, + // so we will put `uint8_t *` here for clarity. + uint8_t *data; + + // The number of bytes of a single element in `data`. + // + // The `SizeT` is treated as `unsigned`. + SizeT itemsize; + + // The number of dimensions of this shape. + // + // The `SizeT` is treated as `unsigned`. + SizeT ndims; + + // Array shape, with length equal to `ndims`. + // + // The `SizeT` is treated as `unsigned`. + // + // NOTE: `shape` can contain 0. + // (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`) + SizeT *shape; + + // Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`. + // + // The `SizeT` is treated as `signed`. + // + // NOTE: `strides` can have negative numbers. + // (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`) + SizeT *strides; + + // Calculate the size/# of elements of an `ndarray`. + // This function corresponds to `np.size()` or `ndarray.size` + SizeT size() { + return ndarray_util::calc_size_from_shape(ndims, shape); + } + + // Calculate the number of bytes of its content of an `ndarray` *in its view*. + // This function corresponds to `ndarray.nbytes` + SizeT nbytes() { + return this->size() * itemsize; + } + + void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) { + __builtin_memcpy(pelement, pvalue, itemsize); + } + + uint8_t* get_pelement(SizeT *indices) { + uint8_t* element = data; + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) + element += indices[dim_i] * strides[dim_i] * itemsize; + return element; + } + + // Is the given `indices` valid/in-bounds? + bool in_bounds(SizeT *indices) { + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) { + bool dim_ok = indices[dim_i] < shape[dim_i]; + if (!dim_ok) return false; + } + return true; + } + + // Fill the ndarray with a value + void fill_generic(uint8_t* pvalue) { + NDArrayIndicesIter iter; + iter.ndims = this->ndims; + iter.shape = this->shape; + iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims); + iter.set_indices_zero(); + + for (SizeT i = 0; i < this->size(); i++, iter.next()) { + uint8_t* pelement = get_pelement(iter.indices); + set_value_at_pelement(pelement, pvalue); + } + } + + // Set the strides of the ndarray with `ndarray_util::set_strides_by_shape` + void set_strides_by_shape() { + ndarray_util::set_strides_by_shape(ndims, strides, shape); + } + + // https://numpy.org/doc/stable/reference/generated/numpy.eye.html + void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) { + __builtin_assume(ndims == 2); + + // TODO: Better implementation + + fill_generic(zero_pvalue); + for (SizeT i = 0; i < min(shape[0], shape[1]); i++) { + SizeT row = i; + SizeT col = i + k; + SizeT indices[2] = { row, col }; + + if (!in_bounds(indices)) continue; + + uint8_t* pelement = get_pelement(indices); + set_value_at_pelement(pelement, one_pvalue); + } + } +}; +} + +extern "C" { + uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return ndarray->size(); + } + + uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return ndarray->size(); + } + + void __nac3_ndarray_fill_generic(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); + } + + void __nac3_ndarray_fill_generic64(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index edd3f5e6..c1efc7cb 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -2,17 +2,21 @@ #include #include +// set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` has it all #define IRRT_DONT_TYPEDEF_INTS -#include "irrt.hpp" +#include "irrt_everything.hpp" -static void __test_fail(const char *file, int line) { - // NOTE: Try to make the location info follow a format that - // VSCode/other IDEs would recognize as a clickable URL. - printf("[!] test_fail() invoked at %s:%d", file, line); +namespace { +static void test_fail() { + printf("[!] Test failed\n"); exit(1); } -#define test_fail() __test_fail(__FILE__, __LINE__); +static void __begin_test(const char* function_name, const char* file, int line) { + printf("######### Running %s @ %s:%d\n", function_name, file, line); +} + +#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__) template bool arrays_match(int len, T *as, T *bs) { @@ -23,40 +27,163 @@ bool arrays_match(int len, T *as, T *bs) { } template -void debug_print_array(const char* format, int len, T *as) { +void debug_print_array(const char* format, int len, T* as) { printf("["); for (int i = 0; i < len; i++) { if (i != 0) printf(", "); printf(format, as[i]); } - printf("]\n"); + printf("]"); } template -bool assert_arrays_match(const char *label, const char *format, int len, T *expected, T *got) { - auto match = arrays_match(len, expected, got); - - if (!match) { +void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) { + if (!arrays_match(len, expected, got)) { printf("expected %s: ", label); debug_print_array(format, len, expected); + printf("\n"); printf("got %s: ", label); debug_print_array(format, len, got); + printf("\n"); + test_fail(); } - - return match; } -static void test_strides_from_shape() { - const uint64_t ndims = 4; - uint64_t shape[ndims] = { 999, 3, 5, 7 }; - uint64_t strides[ndims] = { 0 }; - __nac3_ndarray_strides_from_shape64(ndims, shape, strides); +template +void assert_values_match(const char* label, const char* format, T expected, T got) { + if (expected != got) { + printf("expected %s: ", label); + printf(format, expected); + printf("\n"); + printf("got %s: ", label); + printf(format, got); + printf("\n"); + test_fail(); + } +} - uint64_t expected_strides[ndims] = { 3*5*7, 5*7, 7, 1 }; - if (!assert_arrays_match("strides", "%u", ndims, expected_strides, strides)) test_fail(); +void test_calc_size_from_shape_normal() { + // Test shapes with normal values + BEGIN_TEST(); + + int32_t shape[4] = { 2, 3, 5, 7 }; + debug_print_array("%d", 4, shape); + assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape(4, shape)); +} + +void test_calc_size_from_shape_has_zero() { + // Test shapes with 0 in them + BEGIN_TEST(); + + int32_t shape[4] = { 2, 0, 5, 7 }; + assert_values_match("size", "%d", 0, ndarray_util::calc_size_from_shape(4, shape)); +} + +void test_set_strides_by_shape() { + // Test `set_strides_by_shape()` + BEGIN_TEST(); + + int32_t shape[4] = { 99, 3, 5, 7 }; + int32_t strides[4] = { 0 }; + ndarray_util::set_strides_by_shape(4, strides, shape); + + int32_t expected_strides[4] = { 105, 35, 7, 1 }; + assert_arrays_match("strides", "%u", 4u, expected_strides, strides); +} + +void test_ndarray_indices_iter_normal() { + // Test NDArrayIndicesIter normal behavior + BEGIN_TEST(); + + int32_t shape[3] = { 1, 2, 3 }; + int32_t indices[3] = { 0, 0, 0 }; + auto iter = NDArrayIndicesIter { + .ndims = 3u, + .shape = shape, + .indices = indices + }; + + assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); + iter.next(); + assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 }); + iter.next(); + assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 }); + iter.next(); + assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 }); + iter.next(); + assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 }); + iter.next(); + assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 }); + iter.next(); + assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back + iter.next(); + assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 }); +} + +void test_ndarray_fill_generic() { + // Test ndarray fill_generic + BEGIN_TEST(); + + // Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up + // Also make all the octets non-zero, to see if `memcpy` in `fill_generic` is working perfectly. + uint16_t fill_value = 0xFACE; + + uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 }; // Fill `data` with values that != `999` + int32_t in_itemsize = sizeof(uint16_t); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 2, 3 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides, + }; + ndarray.set_strides_by_shape(); + ndarray.fill_generic((uint8_t*) &fill_value); // `fill_generic` here + + uint16_t expected_data[6] = { fill_value, fill_value, fill_value, fill_value, fill_value, fill_value }; + assert_arrays_match("data", "0x%hX", 6, expected_data, in_data); +} + +void test_ndarray_set_to_eye() { + double in_data[9] = { 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0 }; + int32_t in_itemsize = sizeof(double); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 3, 3 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides, + }; + ndarray.set_strides_by_shape(); + + double zero = 0.0; + double one = 1.0; + ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one); + + assert_values_match("in_data[0]", "%f", 0.0, in_data[0]); + assert_values_match("in_data[1]", "%f", 1.0, in_data[1]); + assert_values_match("in_data[2]", "%f", 0.0, in_data[2]); + assert_values_match("in_data[3]", "%f", 0.0, in_data[3]); + assert_values_match("in_data[4]", "%f", 0.0, in_data[4]); + assert_values_match("in_data[5]", "%f", 1.0, in_data[5]); + assert_values_match("in_data[6]", "%f", 0.0, in_data[6]); + assert_values_match("in_data[7]", "%f", 0.0, in_data[7]); + assert_values_match("in_data[8]", "%f", 0.0, in_data[8]); +} } int main() { - test_strides_from_shape(); + test_calc_size_from_shape_normal(); + test_calc_size_from_shape_has_zero(); + test_set_strides_by_shape(); + test_ndarray_indices_iter_normal(); + test_ndarray_fill_generic(); + test_ndarray_set_to_eye(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/irrt_typedefs.hpp b/nac3core/irrt/irrt_typedefs.hpp new file mode 100644 index 00000000..acd75da8 --- /dev/null +++ b/nac3core/irrt/irrt_typedefs.hpp @@ -0,0 +1,12 @@ +#pragma once + +// This is made toggleable since `irrt_test.cpp` itself would include +// headers that define the `int_t` family. +#ifndef IRRT_DONT_TYPEDEF_INTS +typedef _BitInt(8) int8_t; +typedef unsigned _BitInt(8) uint8_t; +typedef _BitInt(32) int32_t; +typedef unsigned _BitInt(32) uint32_t; +typedef _BitInt(64) int64_t; +typedef unsigned _BitInt(64) uint64_t; +#endif \ No newline at end of file diff --git a/nac3core/irrt/irrt_utils.hpp b/nac3core/irrt/irrt_utils.hpp new file mode 100644 index 00000000..c4630b22 --- /dev/null +++ b/nac3core/irrt/irrt_utils.hpp @@ -0,0 +1,11 @@ +#pragma once + +template +static T max(T a, T b) { + return a > b ? a : b; +} + +template +static T min(T a, T b) { + return a > b ? b : a; +} \ No newline at end of file diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index f271b457..6bb3f843 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -702,53 +702,54 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + todo!() + // let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + // let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); + // let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + // let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + // if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + // let n_sz_eqz = ctx + // .builder + // .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + // .unwrap(); - ctx.make_assert( - generator, - n_sz_eqz, - "0:ValueError", - "zero-size array to reduction operation minimum which has no identity", - [None, None, None], - ctx.current_loc, - ); - } + // ctx.make_assert( + // generator, + // n_sz_eqz, + // "0:ValueError", + // "zero-size array to reduction operation minimum which has no identity", + // [None, None, None], + // ctx.current_loc, + // ); + // } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - } + // let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + // unsafe { + // let identity = + // n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + // ctx.builder.build_store(accumulator_addr, identity).unwrap(); + // } - gen_for_callback_incrementing( - generator, - ctx, - llvm_usize.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; + // gen_for_callback_incrementing( + // generator, + // ctx, + // llvm_usize.const_int(1, false), + // (n_sz, false), + // |generator, ctx, _, idx| { + // let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + // let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); + // ctx.builder.build_store(accumulator_addr, result).unwrap(); - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + // Ok(()) + // }, + // llvm_usize.const_int(1, false), + // )?; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - accumulator + // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + // accumulator } _ => unsupported_type(ctx, FN_NAME, &[a_ty]), @@ -920,53 +921,54 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + todo!() + // let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + // let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); + // let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + // let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + // if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + // let n_sz_eqz = ctx + // .builder + // .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + // .unwrap(); - ctx.make_assert( - generator, - n_sz_eqz, - "0:ValueError", - "zero-size array to reduction operation minimum which has no identity", - [None, None, None], - ctx.current_loc, - ); - } + // ctx.make_assert( + // generator, + // n_sz_eqz, + // "0:ValueError", + // "zero-size array to reduction operation minimum which has no identity", + // [None, None, None], + // ctx.current_loc, + // ); + // } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - } + // let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + // unsafe { + // let identity = + // n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + // ctx.builder.build_store(accumulator_addr, identity).unwrap(); + // } - gen_for_callback_incrementing( - generator, - ctx, - llvm_usize.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; + // gen_for_callback_incrementing( + // generator, + // ctx, + // llvm_usize.const_int(1, false), + // (n_sz, false), + // |generator, ctx, _, idx| { + // let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + // let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); + // ctx.builder.build_store(accumulator_addr, result).unwrap(); - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + // Ok(()) + // }, + // llvm_usize.const_int(1, false), + // )?; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - accumulator + // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + // accumulator } _ => unsupported_type(ctx, FN_NAME, &[a_ty]), diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index d39b55ca..5bcf05ad 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,8 +1,6 @@ use crate::codegen::{ - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, - llvm_intrinsics::call_int_umin, - stmt::gen_for_callback_incrementing, - CodeGenContext, CodeGenerator, + llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, CodeGenContext, + CodeGenerator, }; use inkwell::context::Context; use inkwell::types::{ArrayType, BasicType, StructType}; @@ -12,6 +10,7 @@ use inkwell::{ values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; /// A LLVM type that is used to represent a non-primitive type in NAC3. pub trait ProxyType<'ctx>: Into { @@ -1601,7 +1600,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> IntValue<'ctx> { - call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) + todo!() + // call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) } } @@ -1675,17 +1675,19 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices_elem_ty.get_bit_width() ); - let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + todo!() - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - } + // let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + + // unsafe { + // ctx.builder + // .build_in_bounds_gep( + // self.base_ptr(ctx, generator), + // &[index], + // name.unwrap_or_default(), + // ) + // .unwrap() + // } } fn ptr_offset( @@ -1761,3 +1763,307 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, for NDArrayDataProxy<'ctx, '_> { } + +#[derive(Debug, Clone, Copy)] +pub struct StructField<'ctx> { + /// The GEP index of this struct field. + pub gep_index: u32, + /// Name of this struct field. + /// + /// Used for generating names. + pub name: &'static str, + /// The type of this struct field. + pub ty: BasicTypeEnum<'ctx>, +} + +pub struct StructFields<'ctx> { + /// Name of the struct. + /// + /// Used for generating names. + pub name: &'static str, + + /// All the [`StructField`]s of this struct. + /// + /// **NOTE:** The index position of a [`StructField`] + /// matches the element's [`StructField::index`]. + pub fields: Vec>, +} + +struct StructFieldsBuilder<'ctx> { + gep_index_counter: u32, + /// Name of the struct to be built. + name: &'static str, + fields: Vec>, +} + +impl<'ctx> StructField<'ctx> { + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + ) -> PointerValue<'ctx> { + ctx.builder.build_struct_gep(ptr, self.gep_index, self.name).unwrap() + } + + pub fn load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + ctx.builder.build_load(self.gep(ctx, ptr), self.name).unwrap() + } + + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>, value: V) + where + V: BasicValue<'ctx>, + { + ctx.builder.build_store(ptr, value).unwrap(); + } +} + +type IsInstanceError = String; +type IsInstanceResult = Result<(), IsInstanceError>; + +pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult +where + A: BasicType<'ctx>, + B: BasicType<'ctx>, +{ + let expected = expected.as_basic_type_enum(); + let got = got.as_basic_type_enum(); + + // Put those logic into here, + // otherwise there is always a fallback reporting on any kind of mismatch + match (expected, got) { + (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { + if expected.get_bit_width() != got.get_bit_width() { + return Err(format!( + "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" + )); + } + } + (expected, got) => { + if expected != got { + return Err(format!("Expected {expected}, got {got}")); + } + } + } + Ok(()) +} + +impl<'ctx> StructFields<'ctx> { + pub fn num_fields(&self) -> u32 { + self.fields.len() as u32 + } + + pub fn as_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec(); + ctx.struct_type(llvm_fields.as_slice(), false) + } + + pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult { + // Check scrutinee's number of struct fields + if scrutinee.count_fields() != self.num_fields() { + return Err(format!( + "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}", + struct_name = self.name, + expected_count = self.num_fields(), + got_count = scrutinee.count_fields(), + )); + } + + // Check the scrutinee's field types + for field in self.fields.iter() { + let expected_field_ty = field.ty; + let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap(); + + if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) { + return Err(format!( + "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}", + gep_index = field.gep_index, + struct_name = self.name, + field_name = field.name, + )); + } + } + + // Done + Ok(()) + } +} + +impl<'ctx> StructFieldsBuilder<'ctx> { + fn start(name: &'static str) -> Self { + StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() } + } + + fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + StructField { gep_index: index, name, ty } + } + + fn end(self) -> StructFields<'ctx> { + StructFields { name: self.name, fields: self.fields } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayType<'ctx> { + pub size_type: IntType<'ctx>, + pub elem_type: BasicTypeEnum<'ctx>, +} + +pub struct NpArrayStructFields<'ctx> { + pub whole_struct: StructFields<'ctx>, + pub data: StructField<'ctx>, + pub itemsize: StructField<'ctx>, + pub ndims: StructField<'ctx>, + pub shape: StructField<'ctx>, + pub strides: StructField<'ctx>, +} + +impl<'ctx> NpArrayType<'ctx> { + pub fn new_opaque_elem( + ctx: &CodeGenContext<'ctx, '_>, + size_type: IntType<'ctx>, + ) -> NpArrayType<'ctx> { + NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() } + } + + pub fn struct_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructType<'ctx> { + self.fields().whole_struct.as_struct_type(ctx.ctx) + } + + pub fn fields(&self) -> NpArrayStructFields<'ctx> { + let mut builder = StructFieldsBuilder::start("NpArray"); + + let addrspace = AddressSpace::default(); + + let byte_type = self.size_type.get_context().i8_type(); + + // Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`. + let data = builder.add_field("data", byte_type.ptr_type(addrspace).into()); + let itemsize = builder.add_field("itemsize", self.size_type.into()); + let ndims = builder.add_field("ndims", self.size_type.into()); + let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into()); + let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into()); + + NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides } + } + + /// Allocate an `ndarray` on stack, with the following notes: + /// + /// - `ndarray.ndims` will be initialized to `in_ndims`. + /// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`. + /// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`, + /// all with empty/uninitialized values. + pub fn alloca( + &self, + ctx: &CodeGenContext<'ctx, '_>, + in_ndims: IntValue<'ctx>, + name: &str, + ) -> NpArrayValue<'ctx> { + let fields = self.fields(); + let ptr = + ctx.builder.build_alloca(fields.whole_struct.as_struct_type(ctx.ctx), name).unwrap(); + + // Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides` + let allocated_shape = + ctx.builder.build_array_alloca(fields.shape.ty, in_ndims, "allocated_shape").unwrap(); + let allocated_strides = ctx + .builder + .build_array_alloca(fields.strides.ty, in_ndims, "allocated_strides") + .unwrap(); + + let value = NpArrayValue { ty: *self, ptr }; + value.store_ndims(ctx, in_ndims); + value.store_itemsize(ctx, self.elem_type.size_of().unwrap()); + value.store_shape(ctx, allocated_shape); + value.store_strides(ctx, allocated_strides); + + return value; + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayValue<'ctx> { + pub ty: NpArrayType<'ctx>, + pub ptr: PointerValue<'ctx>, +} + +impl<'ctx> NpArrayValue<'ctx> { + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let field = self.ty.fields().ndims; + field.load(ctx, self.ptr).into_int_value() + } + + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + let field = self.ty.fields().ndims; + field.store(ctx, self.ptr, value); + } + + pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let field = self.ty.fields().itemsize; + field.load(ctx, self.ptr).into_int_value() + } + + pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + let field = self.ty.fields().itemsize; + field.store(ctx, self.ptr, value); + } + + pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let field = self.ty.fields().shape; + field.load(ctx, self.ptr).into_pointer_value() + } + + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + let field = self.ty.fields().shape; + field.store(ctx, self.ptr, value); + } + + pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let field = self.ty.fields().strides; + field.load(ctx, self.ptr).into_pointer_value() + } + + pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + let field = self.ty.fields().strides; + field.store(ctx, self.ptr, value); + } + + /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! + pub fn shape_slice( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let field = self.ty.fields().shape; + field.gep(ctx, self.ptr); + + let ndims = self.load_ndims(ctx); + + TypedArrayLikeAdapter { + adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)), + downcast_fn: Box::new(|_ctx, x| x.into_int_value()), + upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), + } + } + + /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! + pub fn strides_slice( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let field = self.ty.fields().strides; + field.gep(ctx, self.ptr); + + let ndims = self.load_ndims(ctx); + + TypedArrayLikeAdapter { + adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)), + downcast_fn: Box::new(|_ctx, x| x.into_int_value()), + upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), + } + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index c98566f9..06cf979a 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,10 +1,14 @@ -use crate::typecheck::typedef::Type; +use crate::{ + codegen::classes::{NDArrayType, NpArrayType}, + typecheck::typedef::Type, + util::SizeVariant, +}; mod test; use super::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, llvm_intrinsics, CodeGenContext, CodeGenerator, @@ -16,8 +20,8 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType}, + values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, PointerValue}, AddressSpace, IntPredicate, }; use itertools::Either; @@ -565,367 +569,62 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, +fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { + match ty.get_bit_width() { + 32 => SizeVariant::Bits32, + 64 => SizeVariant::Bits64, + _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), + } +} + +fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>( ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> + size_type: IntType<'ctx>, + base_name: &str, + build_func_type: BuildFuncTypeFn, +) -> FunctionValue<'ctx> where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, + BuildFuncTypeFn: Fn() -> FunctionType<'ctx>, { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + let mut fn_name = base_name.to_owned(); + match get_size_variant(size_type) { + SizeVariant::Bits32 => { + // The original fn_name is the correct function name + } + SizeVariant::Bits64 => { + // Append "64" at the end, this is the naming convention for 64-bit + fn_name.push_str("64"); + } + } - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); + // Get (or declare then get if does not exist) the corresponding function + ctx.module.get_function(&fn_name).unwrap_or_else(|| { + let fn_type = build_func_type(); + ctx.module.add_function(&fn_name, fn_type, None) + }) +} + +fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> { + let i8_type = ctx.i8_type(); + + let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; + let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx); + struct_ty.ptr_type(AddressSpace::default()) +} + +pub fn call_nac3_ndarray_size<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, +) -> IntValue<'ctx> { + let size_type = ndarray.ty.size_type; + let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || { + size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false) + }); - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) + .build_call(function, &[ndarray.ptr.into()], "size") .unwrap() + .try_as_basic_value() + .unwrap_left() + .into_int_value() } - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.dim_sizes().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} \ No newline at end of file diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9724f6f2..b280ca5b 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,16 +1,13 @@ use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, - ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, + NDArrayType, NDArrayValue, NpArrayType, NpArrayValue, ProxyType, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, call_ndarray_calc_broadcast, - call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, - call_ndarray_calc_size, - }, + irrt::calculate_len_for_slice_range, llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, @@ -26,14 +23,140 @@ use crate::{ typedef::{FunSignature, Type, TypeEnum}, }, }; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; +use inkwell::{ + types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, + values::BasicValue, +}; use nac3parser::ast::{Operator, StrRef}; +fn memory_copy_slice<'ctx, G, T, Dst, Src>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst: Dst, + src: Src, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + Dst: TypedArrayLikeMutator<'ctx, T>, + Src: TypedArrayLikeAccessor<'ctx, T>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Check `src.size` == `dst.size`, otherwise throw an Exception + let size_ok = ctx + .builder + .build_int_compare(IntPredicate::EQ, src.size(ctx, generator), dst.size(ctx, generator), "") + .unwrap(); + ctx.make_assert( + generator, + size_ok, + "0:ValueError", + "copy slice mismatched", + [None, None, None], + ctx.current_loc, + ); + + // Copy data + let len = dst.size(ctx, generator); + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (len, false), + |generator, ctx, _, idx| { + let value = src.get_typed(ctx, generator, &idx, None); + dst.set_typed(ctx, generator, &idx, value); + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + Ok(()) +} + +fn allocate_ndarray<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_type: BasicTypeEnum<'ctx>, + in_ndims: IntValue<'ctx>, + name: &'static str, +) -> NpArrayValue<'ctx> +where + G: CodeGenerator + ?Sized, +{ + let size_type = generator.get_size_type(ctx.ctx); + let ndarray_ty = NpArrayType { elem_type, size_type }; + ndarray_ty.alloca(ctx, in_ndims, name) +} + +fn user_shape_set<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape: BasicValueEnum<'ctx>, + in_shape_ty: Type, + dst_shape: TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, +) -> Result<(), String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Check `in_shape_ty` to determine what to do determining on the user's input + match &*ctx.unifier.get_ty(in_shape_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of ints; e.g., `np.empty([600, 800, 3])` + + // NOTE: If there are no logic errors, the list's element type MUST BE int32. + + // List has to be a pointer + let BasicValueEnum::PointerValue(shape_list_ptr) = in_shape else { unreachable!() }; + + let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); + memory_copy_slice( + generator, + ctx, + dst_shape, + TypedArrayLikeAdapter::from( + shape_list.data(), + Box::new(|_ctx, value| value.into_int_value()), + Box::new(|_ctx, value| value.as_basic_value_enum()), + ), + )?; + } + TypeEnum::TTuple { ty, .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + // Tuple has to be a struct + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + let BasicValueEnum::StructValue(shape_tuple) = in_shape else { unreachable!() }; + + let ndims = ty.len(); + for dim_i in 0..ndims { + let dim = ctx + .builder + .build_extract_value(shape_tuple, dim_i as u32, format!("dim{dim_i}").as_str()) + .unwrap() + .into_int_value(); + + let idx = llvm_usize.const_int(dim_i as u64, false); + dst_shape.set_typed(ctx, generator, &idx, dim); + } + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + let shape_int = in_shape.into_int_value(); + dst_shape.set_typed(ctx, generator, &llvm_usize.const_zero(), shape_int); + } + _ => unreachable!(), + } + Ok(()) +} + // /// Creates an uninitialized `NDArray` instance. // fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -41,20 +164,20 @@ use nac3parser::ast::{Operator, StrRef}; // elem_ty: Type, // ) -> Result, String> { // let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); -// +// // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let llvm_ndarray_t = ctx // .get_llvm_type(generator, ndarray_ty) // .into_pointer_type() // .get_element_type() // .into_struct_type(); -// +// // let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; -// +// // Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) // } -// +// // /// Creates an `NDArray` instance from a dynamic shape. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -80,7 +203,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String>, // { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // // Assert that all dimensions are non-negative // let shape_len = shape_len_fn(generator, ctx, shape)?; // gen_for_callback_incrementing( @@ -91,7 +214,7 @@ use nac3parser::ast::{Operator, StrRef}; // |generator, ctx, _, i| { // let shape_dim = shape_data_fn(generator, ctx, shape, i)?; // debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); -// +// // let shape_dim_gez = ctx // .builder // .build_int_compare( @@ -101,7 +224,7 @@ use nac3parser::ast::{Operator, StrRef}; // "", // ) // .unwrap(); -// +// // ctx.make_assert( // generator, // shape_dim_gez, @@ -110,22 +233,22 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // // TODO: Disallow dim_sz > u32_MAX -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // )?; -// +// // let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; -// +// // let num_dims = shape_len_fn(generator, ctx, shape)?; // ndarray.store_ndims(ctx, generator, num_dims); -// +// // let ndarray_num_dims = ndarray.load_ndims(ctx); // ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); -// +// // // Copy the dimension sizes from shape to ndarray.dims // let shape_len = shape_len_fn(generator, ctx, shape)?; // gen_for_callback_incrementing( @@ -137,22 +260,22 @@ use nac3parser::ast::{Operator, StrRef}; // let shape_dim = shape_data_fn(generator, ctx, shape, i)?; // debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); // let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); -// +// // let ndarray_pdim = // unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; -// +// // ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // )?; -// +// // let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); -// +// // Ok(ndarray) // } -// +// // /// Creates an `NDArray` instance from a constant shape. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -164,14 +287,14 @@ use nac3parser::ast::{Operator, StrRef}; // shape: &[IntValue<'ctx>], // ) -> Result, String> { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // for &shape_dim in shape { // let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); // let shape_dim_gez = ctx // .builder // .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") // .unwrap(); -// +// // ctx.make_assert( // generator, // shape_dim_gez, @@ -180,18 +303,18 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // // TODO: Disallow dim_sz > u32_MAX // } -// +// // let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; -// +// // let num_dims = llvm_usize.const_int(shape.len() as u64, false); // ndarray.store_ndims(ctx, generator, num_dims); -// +// // let ndarray_num_dims = ndarray.load_ndims(ctx); // ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); -// +// // for (i, &shape_dim) in shape.iter().enumerate() { // let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); // let ndarray_dim = unsafe { @@ -202,15 +325,15 @@ use nac3parser::ast::{Operator, StrRef}; // None, // ) // }; -// +// // ctx.builder.build_store(ndarray_dim, shape_dim).unwrap(); // } -// +// // let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); -// +// // Ok(ndarray) // } -// +// // /// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields. // fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -220,7 +343,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> NDArrayValue<'ctx> { // let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); // assert!(llvm_ndarray_data_t.is_sized()); -// +// // let ndarray_num_elems = call_ndarray_calc_size( // generator, // ctx, @@ -228,10 +351,10 @@ use nac3parser::ast::{Operator, StrRef}; // (None, None), // ); // ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); -// +// // ndarray // } -// +// // fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, // ctx: &mut CodeGenContext<'ctx, '_>, @@ -257,7 +380,7 @@ use nac3parser::ast::{Operator, StrRef}; // unreachable!() // } // } -// +// // fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, // ctx: &mut CodeGenContext<'ctx, '_>, @@ -285,7 +408,7 @@ use nac3parser::ast::{Operator, StrRef}; // unreachable!() // } // } -// +// // /// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -307,13 +430,13 @@ use nac3parser::ast::{Operator, StrRef}; // shape: BasicValueEnum<'ctx>, // ) -> Result, String> { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // match shape { // BasicValueEnum::PointerValue(shape_list_ptr) // if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => // { // // 1. A list of ints; e.g., `np.empty([600, 800, 3])` -// +// // let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); // create_ndarray_dyn_shape( // generator, @@ -329,10 +452,10 @@ use nac3parser::ast::{Operator, StrRef}; // BasicValueEnum::StructValue(shape_tuple) => { // // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` // // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. -// +// // // Get the length/size of the tuple, which also happens to be the value of `ndims`. // let ndims = shape_tuple.get_type().count_fields(); -// +// // let mut shape = Vec::with_capacity(ndims as usize); // for dim_i in 0..ndims { // let dim = ctx @@ -340,20 +463,20 @@ use nac3parser::ast::{Operator, StrRef}; // .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) // .unwrap() // .into_int_value(); -// +// // shape.push(dim); // } // create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) // } // BasicValueEnum::IntValue(shape_int) => { // // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -// +// // create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) // } // _ => unreachable!(), // } // } -// +// // /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as // /// its input. // fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -371,14 +494,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String>, // { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let ndarray_num_elems = call_ndarray_calc_size( // generator, // ctx, // &ndarray.dim_sizes().as_slice_value(ctx, generator), // (None, None), // ); -// +// // gen_for_callback_incrementing( // generator, // ctx, @@ -386,16 +509,16 @@ use nac3parser::ast::{Operator, StrRef}; // (ndarray_num_elems, false), // |generator, ctx, _, i| { // let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; -// +// // let value = value_fn(generator, ctx, i)?; // ctx.builder.build_store(elem, value).unwrap(); -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // ) // } -// +// // /// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices // /// as its input. // fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( @@ -414,11 +537,11 @@ use nac3parser::ast::{Operator, StrRef}; // { // ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { // let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); -// +// // value_fn(generator, ctx, &indices) // }) // } -// +// // fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( // generator: &mut G, // ctx: &mut CodeGenContext<'ctx, 'a>, @@ -436,11 +559,11 @@ use nac3parser::ast::{Operator, StrRef}; // { // ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { // let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; -// +// // map_fn(generator, ctx, elem) // }) // } -// +// // /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of // /// the target `ndarray`. // fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( @@ -451,7 +574,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) { // let array_ndims = source.load_ndims(ctx); // let broadcast_size = target.load_ndims(ctx); -// +// // ctx.make_assert( // generator, // ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), @@ -461,7 +584,7 @@ use nac3parser::ast::{Operator, StrRef}; // ctx.current_loc, // ); // } -// +// // /// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value // /// with broadcast-compatible shapes. // fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( @@ -481,53 +604,53 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String>, // { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let (lhs_val, lhs_scalar) = lhs; // let (rhs_val, rhs_scalar) = rhs; -// +// // assert!( // !(lhs_scalar && rhs_scalar), // "One of the operands must be a ndarray instance: `{}`, `{}`", // lhs_val.get_type(), // rhs_val.get_type() // ); -// +// // // Assert that all ndarray operands are broadcastable to the target size // if !lhs_scalar { // let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); // ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); // } -// +// // if !rhs_scalar { // let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); // ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); // } -// +// // ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { // let lhs_elem = if lhs_scalar { // lhs_val // } else { // let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); // let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); -// +// // unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } // }; -// +// // let rhs_elem = if rhs_scalar { // rhs_val // } else { // let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); // let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); -// +// // unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } // }; -// +// // value_fn(generator, ctx, (lhs_elem, rhs_elem)) // })?; -// +// // Ok(res) // } -// +// // /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -548,17 +671,17 @@ use nac3parser::ast::{Operator, StrRef}; // ctx.primitives.str, // ]; // assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); -// +// // let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; // ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { // let value = ndarray_zero_value(generator, ctx, elem_ty); -// +// // Ok(value) // })?; -// +// // Ok(ndarray) // } -// +// // /// LLVM-typed implementation for generating the implementation for `ndarray.ones`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -579,17 +702,17 @@ use nac3parser::ast::{Operator, StrRef}; // ctx.primitives.str, // ]; // assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); -// +// // let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; // ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { // let value = ndarray_one_value(generator, ctx, elem_ty); -// +// // Ok(value) // })?; -// +// // Ok(ndarray) // } -// +// // /// LLVM-typed implementation for generating the implementation for `ndarray.full`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -605,9 +728,9 @@ use nac3parser::ast::{Operator, StrRef}; // ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { // let value = if fill_value.is_pointer_value() { // let llvm_i1 = ctx.ctx.bool_type(); -// +// // let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; -// +// // call_memcpy_generic( // ctx, // copy, @@ -615,20 +738,20 @@ use nac3parser::ast::{Operator, StrRef}; // fill_value.get_type().size_of().map(Into::into).unwrap(), // llvm_i1.const_zero(), // ); -// +// // copy.into() // } else if fill_value.is_int_value() || fill_value.is_float_value() { // fill_value // } else { // unreachable!() // }; -// +// // Ok(value) // })?; -// +// // Ok(ndarray) // } -// +// // /// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. // fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( // generator: &G, @@ -636,24 +759,24 @@ use nac3parser::ast::{Operator, StrRef}; // ty: PointerType<'ctx>, // ) -> IntValue<'ctx> { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let list_ty = ListType::from_type(ty, llvm_usize); // let list_elem_ty = list_ty.element_type(); -// +// // let ndims = llvm_usize.const_int(1, false); // match list_elem_ty { // AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { // ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) // } -// +// // AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { // todo!("Getting ndims for list[ndarray] not supported") // } -// +// // _ => ndims, // } // } -// +// // /// Returns the number of dimensions for an array-like object as an [`IntValue`]. // fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -661,20 +784,20 @@ use nac3parser::ast::{Operator, StrRef}; // value: BasicValueEnum<'ctx>, // ) -> IntValue<'ctx> { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // match value { // BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => { // NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx) // } -// +// // BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => { // llvm_ndlist_get_ndims(generator, ctx, v.get_type()) // } -// +// // _ => llvm_usize.const_zero(), // } // } -// +// // /// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. // fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -686,9 +809,9 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result<(), String> { // let llvm_i1 = ctx.ctx.bool_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let list_elem_ty = src_lst.get_type().element_type(); -// +// // match list_elem_ty { // AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { // // The stride of elements in this dimension, i.e. the number of elements between arr[i] @@ -699,7 +822,7 @@ use nac3parser::ast::{Operator, StrRef}; // &dst_arr.dim_sizes(), // (Some(llvm_usize.const_int(dim + 1, false)), None), // ); -// +// // gen_for_range_callback( // generator, // ctx, @@ -709,17 +832,17 @@ use nac3parser::ast::{Operator, StrRef}; // |_, _| Ok(llvm_usize.const_int(1, false)), // |generator, ctx, i| { // let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); -// +// // let dst_ptr = // unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; -// +// // let nested_lst_elem = ListValue::from_ptr_val( // unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } // .into_pointer_value(), // llvm_usize, // None, // ); -// +// // ndarray_from_ndlist_impl( // generator, // ctx, @@ -728,21 +851,21 @@ use nac3parser::ast::{Operator, StrRef}; // nested_lst_elem, // dim + 1, // )?; -// +// // Ok(()) // }, // )?; // } -// +// // AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { // todo!("Not implemented for list[ndarray]") // } -// +// // _ => { // let lst_len = src_lst.load_size(ctx, None); // let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); // let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); -// +// // let cpy_len = ctx // .builder // .build_int_mul( @@ -751,7 +874,7 @@ use nac3parser::ast::{Operator, StrRef}; // "", // ) // .unwrap(); -// +// // call_memcpy_generic( // ctx, // dst_slice_ptr, @@ -761,10 +884,10 @@ use nac3parser::ast::{Operator, StrRef}; // ); // } // } -// +// // Ok(()) // } -// +// // /// LLVM-typed implementation for `ndarray.array`. // fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -776,28 +899,28 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // let llvm_i1 = ctx.ctx.bool_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); -// +// // // TODO(Derppening): Add assertions for sizes of different dimensions -// +// // // object is not a pointer - 0-dim NDArray // if !object.is_pointer_value() { // let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; -// +// // unsafe { // ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); // } -// +// // return Ok(ndarray); // } -// +// // let object = object.into_pointer_value(); -// +// // // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims // if NDArrayValue::is_instance(object, llvm_usize).is_ok() { // let object = NDArrayValue::from_ptr_val(object, llvm_usize, None); -// +// // let ndarray = gen_if_else_expr_callback( // generator, // ctx, @@ -810,7 +933,7 @@ use nac3parser::ast::{Operator, StrRef}; // .builder // .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") // .unwrap(); -// +// // Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) // }, // |generator, ctx| { @@ -825,7 +948,7 @@ use nac3parser::ast::{Operator, StrRef}; // .builder // .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") // .unwrap(); -// +// // Ok(ctx // .builder // .build_select(ndmin_gt_ndims, ndmin, ndims, "") @@ -837,7 +960,7 @@ use nac3parser::ast::{Operator, StrRef}; // let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); // // The number of dimensions to prepend 1's to // let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); -// +// // Ok(gen_if_else_expr_callback( // generator, // ctx, @@ -854,7 +977,7 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap()) // }, // )?; -// +// // ndarray_sliced_copyto_impl( // generator, // ctx, @@ -864,28 +987,28 @@ use nac3parser::ast::{Operator, StrRef}; // 0, // &[], // )?; -// +// // Ok(Some(ndarray.as_base_value())) // }, // |_, _| Ok(Some(object.as_base_value())), // )?; -// +// // return Ok(NDArrayValue::from_ptr_val( // ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), // llvm_usize, // None, // )); // } -// +// // // Remaining case: TList // assert!(ListValue::is_instance(object, llvm_usize).is_ok()); // let object = ListValue::from_ptr_val(object, llvm_usize, None); -// +// // // The number of dimensions to prepend 1's to // let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); // let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); // let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); -// +// // let ndarray = create_ndarray_dyn_shape( // generator, // ctx, @@ -895,7 +1018,7 @@ use nac3parser::ast::{Operator, StrRef}; // let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); // let ndmin_gt_ndims = // ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); -// +// // Ok(ctx // .builder // .build_select(ndmin_gt_ndims, ndmin, ndims, "") @@ -917,11 +1040,11 @@ use nac3parser::ast::{Operator, StrRef}; // false, // ) // }; -// +// // let llvm_i8 = ctx.ctx.i8_type(); // let llvm_list_i8 = make_llvm_list(llvm_i8.into()); // let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); -// +// // // Cast list to { i8*, usize } since we only care about the size // let lst = generator // .gen_var_alloc( @@ -938,7 +1061,7 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap(), // ) // .unwrap(); -// +// // let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); // gen_for_range_callback( // generator, @@ -950,7 +1073,7 @@ use nac3parser::ast::{Operator, StrRef}; // |generator, ctx, _| { // let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) // .ptr_type(AddressSpace::default()); -// +// // let this_dim = ctx // .builder // .build_load(lst, "") @@ -959,9 +1082,9 @@ use nac3parser::ast::{Operator, StrRef}; // .map(BasicValueEnum::into_pointer_value) // .unwrap(); // let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); -// +// // // TODO: Assert this_dim.sz != 0 -// +// // let next_dim = unsafe { // this_dim.data().get_unchecked( // ctx, @@ -977,11 +1100,11 @@ use nac3parser::ast::{Operator, StrRef}; // ctx.builder.build_bitcast(next_dim, llvm_plist_i8, "").unwrap(), // ) // .unwrap(); -// +// // Ok(()) // }, // )?; -// +// // let lst = ListValue::from_ptr_val( // ctx.builder // .build_load(lst, "") @@ -990,7 +1113,7 @@ use nac3parser::ast::{Operator, StrRef}; // llvm_usize, // None, // ); -// +// // Ok(Some(lst.load_size(ctx, None))) // }, // )? @@ -998,7 +1121,7 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap()) // }, // )?; -// +// // ndarray_from_ndlist_impl( // generator, // ctx, @@ -1007,10 +1130,10 @@ use nac3parser::ast::{Operator, StrRef}; // object, // 0, // )?; -// +// // Ok(ndarray) // } -// +// // /// LLVM-typed implementation for generating the implementation for `ndarray.eye`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -1024,12 +1147,12 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // let llvm_i32 = ctx.ctx.i32_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); // let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); -// +// // let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; -// +// // ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { // let (row, col) = unsafe { // ( @@ -1037,7 +1160,7 @@ use nac3parser::ast::{Operator, StrRef}; // indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), // ) // }; -// +// // let col_with_offset = ctx // .builder // .build_int_add( @@ -1048,18 +1171,18 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap(); // let is_on_diag = // ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); -// +// // let zero = ndarray_zero_value(generator, ctx, elem_ty); // let one = ndarray_one_value(generator, ctx, elem_ty); -// +// // let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); -// +// // Ok(value) // })?; -// +// // Ok(ndarray) // } -// +// // /// Copies a slice of an [`NDArrayValue`] to another. // /// // /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz` @@ -1083,7 +1206,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result<(), String> { // let llvm_i1 = ctx.ctx.bool_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // // If there are no (remaining) slice expressions, memcpy the entire dimension // if slices.is_empty() { // let stride = call_ndarray_calc_size( @@ -1094,12 +1217,12 @@ use nac3parser::ast::{Operator, StrRef}; // ); // let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); // let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); -// +// // call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); -// +// // return Ok(()); // } -// +// // // The stride of elements in this dimension, i.e. the number of elements between arr[i] and // // arr[i + 1] in this dimension // let src_stride = call_ndarray_calc_size( @@ -1114,15 +1237,15 @@ use nac3parser::ast::{Operator, StrRef}; // &dst_arr.dim_sizes(), // (Some(llvm_usize.const_int(dim + 1, false)), None), // ); -// +// // let (start, stop, step) = slices[0]; // let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); // let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); // let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); -// +// // let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); // ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); -// +// // gen_for_range_callback( // generator, // ctx, @@ -1136,14 +1259,14 @@ use nac3parser::ast::{Operator, StrRef}; // let dst_i = // ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); // let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); -// +// // let (src_ptr, dst_ptr) = unsafe { // ( // ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), // ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), // ) // }; -// +// // ndarray_sliced_copyto_impl( // generator, // ctx, @@ -1153,20 +1276,20 @@ use nac3parser::ast::{Operator, StrRef}; // dim + 1, // &slices[1..], // )?; -// +// // let dst_i = // ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); // let dst_i_add1 = // ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); // ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); -// +// // Ok(()) // }, // )?; -// +// // Ok(()) // } -// +// // /// Copies a [`NDArrayValue`] using slices. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -1181,7 +1304,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // let llvm_i32 = ctx.ctx.i32_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let ndarray = if slices.is_empty() { // create_ndarray_dyn_shape( // generator, @@ -1196,10 +1319,10 @@ use nac3parser::ast::{Operator, StrRef}; // } else { // let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; // ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); -// +// // let ndims = this.load_ndims(ctx); // ndarray.create_dim_sizes(ctx, llvm_usize, ndims); -// +// // // Populate the first slices.len() dimensions by computing the size of each dim slice // for (i, (start, stop, step)) in slices.iter().enumerate() { // // HACK: workaround calculate_len_for_slice_range requiring exclusive stop @@ -1224,11 +1347,11 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .map(BasicValueEnum::into_int_value) // .unwrap(); -// +// // let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); // let slice_len = // ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); -// +// // unsafe { // ndarray.dim_sizes().set_typed_unchecked( // ctx, @@ -1238,7 +1361,7 @@ use nac3parser::ast::{Operator, StrRef}; // ); // } // } -// +// // // Populate the rest by directly copying the dim size from the source array // gen_for_callback_incrementing( // generator, @@ -1250,16 +1373,16 @@ use nac3parser::ast::{Operator, StrRef}; // let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); // ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); // } -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // ) // .unwrap(); -// +// // ndarray_init_data(generator, ctx, elem_ty, ndarray) // }; -// +// // ndarray_sliced_copyto_impl( // generator, // ctx, @@ -1269,10 +1392,10 @@ use nac3parser::ast::{Operator, StrRef}; // 0, // slices, // )?; -// +// // Ok(ndarray) // } -// +// // /// LLVM-typed implementation for generating the implementation for `ndarray.copy`. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -1284,7 +1407,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) // } -// +// // pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( // generator: &mut G, // ctx: &mut CodeGenContext<'ctx, 'a>, @@ -1314,14 +1437,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .unwrap() // }); -// +// // ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { // map_fn(generator, ctx, elem) // })?; -// +// // Ok(res) // } -// +// // /// LLVM-typed implementation for computing elementwise binary operations on two input operands. // /// // /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output @@ -1359,26 +1482,26 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String>, // { // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let (lhs_val, lhs_scalar) = lhs; // let (rhs_val, rhs_scalar) = rhs; -// +// // assert!( // !(lhs_scalar && rhs_scalar), // "One of the operands must be a ndarray instance: `{}`, `{}`", // lhs_val.get_type(), // rhs_val.get_type() // ); -// +// // let ndarray = res.unwrap_or_else(|| { // if lhs_scalar && rhs_scalar { // let lhs_val = // NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); // let rhs_val = // NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); -// +// // let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); -// +// // create_ndarray_dyn_shape( // generator, // ctx, @@ -1396,7 +1519,7 @@ use nac3parser::ast::{Operator, StrRef}; // llvm_usize, // None, // ); -// +// // create_ndarray_dyn_shape( // generator, // ctx, @@ -1410,14 +1533,14 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap() // } // }); -// +// // ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { // value_fn(generator, ctx, elems) // })?; -// +// // Ok(ndarray) // } -// +// // /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. // /// // /// * `elem_ty` - The element type of the `NDArray`. @@ -1433,11 +1556,11 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // let llvm_i32 = ctx.ctx.i32_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // if cfg!(debug_assertions) { // let lhs_ndims = lhs.load_ndims(ctx); // let rhs_ndims = rhs.load_ndims(ctx); -// +// // // lhs.ndims == 2 // ctx.make_assert( // generator, @@ -1449,7 +1572,7 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // // rhs.ndims == 2 // ctx.make_assert( // generator, @@ -1461,7 +1584,7 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // if let Some(res) = res { // let res_ndims = res.load_ndims(ctx); // let res_dim0 = unsafe { @@ -1486,7 +1609,7 @@ use nac3parser::ast::{Operator, StrRef}; // None, // ) // }; -// +// // // res.ndims == 2 // ctx.make_assert( // generator, @@ -1503,7 +1626,7 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // // res.dims[0] == lhs.dims[0] // ctx.make_assert( // generator, @@ -1513,7 +1636,7 @@ use nac3parser::ast::{Operator, StrRef}; // [None, None, None], // ctx.current_loc, // ); -// +// // // res.dims[1] == rhs.dims[0] // ctx.make_assert( // generator, @@ -1525,7 +1648,7 @@ use nac3parser::ast::{Operator, StrRef}; // ); // } // } -// +// // if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { // let lhs_dim1 = unsafe { // lhs.dim_sizes().get_typed_unchecked( @@ -1538,7 +1661,7 @@ use nac3parser::ast::{Operator, StrRef}; // let rhs_dim0 = unsafe { // rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) // }; -// +// // // lhs.dims[1] == rhs.dims[0] // ctx.make_assert( // generator, @@ -1549,13 +1672,13 @@ use nac3parser::ast::{Operator, StrRef}; // ctx.current_loc, // ); // } -// +// // let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { // ndarray_copy_impl(generator, ctx, elem_ty, lhs)? // } else { // lhs // }; -// +// // let ndarray = res.unwrap_or_else(|| { // create_ndarray_dyn_shape( // generator, @@ -1599,9 +1722,9 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .unwrap() // }); -// +// // let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); -// +// // ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { // llvm_intrinsics::call_expect( // ctx, @@ -1609,7 +1732,7 @@ use nac3parser::ast::{Operator, StrRef}; // idx.size(ctx, generator), // None, // ); -// +// // let common_dim = { // let lhs_idx1 = unsafe { // lhs.dim_sizes().get_typed_unchecked( @@ -1622,28 +1745,28 @@ use nac3parser::ast::{Operator, StrRef}; // let rhs_idx0 = unsafe { // rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) // }; -// +// // let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); -// +// // ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() // }; -// +// // let idx0 = unsafe { // let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); -// +// // ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() // }; // let idx1 = unsafe { // let idx1 = // idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); -// +// // ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() // }; -// +// // let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; // let result_identity = ndarray_zero_value(generator, ctx, elem_ty); // ctx.builder.build_store(result_addr, result_identity).unwrap(); -// +// // gen_for_callback_incrementing( // generator, // ctx, @@ -1651,18 +1774,18 @@ use nac3parser::ast::{Operator, StrRef}; // (common_dim, false), // |generator, ctx, _, i| { // let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); -// +// // let ab_idx = generator.gen_array_var_alloc( // ctx, // llvm_i32.into(), // llvm_usize.const_int(2, false), // None, // )?; -// +// // let a = unsafe { // ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); // ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); -// +// // lhs.data().get_unchecked(ctx, generator, &ab_idx, None) // }; // let b = unsafe { @@ -1673,10 +1796,10 @@ use nac3parser::ast::{Operator, StrRef}; // &llvm_usize.const_int(1, false), // idx1.into(), // ); -// +// // rhs.data().get_unchecked(ctx, generator, &ab_idx, None) // }; -// +// // let a_mul_b = gen_binop_expr_with_values( // generator, // ctx, @@ -1687,7 +1810,7 @@ use nac3parser::ast::{Operator, StrRef}; // )? // .unwrap() // .to_basic_value_enum(ctx, generator, elem_ty)?; -// +// // let result = ctx.builder.build_load(result_addr, "").unwrap(); // let result = gen_binop_expr_with_values( // generator, @@ -1700,19 +1823,19 @@ use nac3parser::ast::{Operator, StrRef}; // .unwrap() // .to_basic_value_enum(ctx, generator, elem_ty)?; // ctx.builder.build_store(result_addr, result).unwrap(); -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // )?; -// +// // let result = ctx.builder.build_load(result_addr, "").unwrap(); // Ok(result) // })?; -// +// // Ok(ndarray) // } -// +// // /// Generates LLVM IR for `ndarray.empty`. // pub fn gen_ndarray_empty<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1723,14 +1846,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert_eq!(args.len(), 1); -// +// // let shape_ty = fun.0.args[0].ty; // let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; -// +// // call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.zeros`. // pub fn gen_ndarray_zeros<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1741,14 +1864,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert_eq!(args.len(), 1); -// +// // let shape_ty = fun.0.args[0].ty; // let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; -// +// // call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.ones`. // pub fn gen_ndarray_ones<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1759,14 +1882,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert_eq!(args.len(), 1); -// +// // let shape_ty = fun.0.args[0].ty; // let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; -// +// // call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.full`. // pub fn gen_ndarray_full<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1777,17 +1900,17 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert_eq!(args.len(), 2); -// +// // let shape_ty = fun.0.args[0].ty; // let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; // let fill_value_ty = fun.0.args[1].ty; // let fill_value_arg = // args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; -// +// // call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) // .map(NDArrayValue::into) // } -// +// // pub fn gen_ndarray_array<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, // obj: &Option<(Type, ValueEnum<'ctx>)>, @@ -1797,13 +1920,13 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert!(matches!(args.len(), 1..=3)); -// +// // let obj_ty = fun.0.args[0].ty; // let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { // TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { // unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 // } -// +// // TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { // let mut ty = *params.iter().next().unwrap().1; // while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) @@ -1811,16 +1934,16 @@ use nac3parser::ast::{Operator, StrRef}; // if *obj_id != PrimDef::List.id() { // break; // } -// +// // ty = *params.iter().next().unwrap().1; // } // ty // } -// +// // _ => obj_ty, // }; // let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; -// +// // let copy_arg = if let Some(arg) = // args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) // { @@ -1833,7 +1956,7 @@ use nac3parser::ast::{Operator, StrRef}; // fun.0.args[1].ty, // ) // }; -// +// // let ndmin_arg = if let Some(arg) = // args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) // { @@ -1846,7 +1969,7 @@ use nac3parser::ast::{Operator, StrRef}; // fun.0.args[2].ty, // ) // }; -// +// // call_ndarray_array_impl( // generator, // context, @@ -1857,7 +1980,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.eye`. // pub fn gen_ndarray_eye<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1868,10 +1991,10 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert!(matches!(args.len(), 1..=3)); -// +// // let nrows_ty = fun.0.args[0].ty; // let nrows_arg = args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)?; -// +// // let ncols_ty = fun.0.args[1].ty; // let ncols_arg = if let Some(arg) = // args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) @@ -1880,7 +2003,7 @@ use nac3parser::ast::{Operator, StrRef}; // } else { // args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty) // }?; -// +// // let offset_ty = fun.0.args[2].ty; // let offset_arg = if let Some(arg) = // args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) @@ -1893,7 +2016,7 @@ use nac3parser::ast::{Operator, StrRef}; // offset_ty, // )) // }?; -// +// // call_ndarray_eye_impl( // generator, // context, @@ -1904,7 +2027,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.identity`. // pub fn gen_ndarray_identity<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1915,12 +2038,12 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_none()); // assert_eq!(args.len(), 1); -// +// // let llvm_usize = generator.get_size_type(context.ctx); -// +// // let n_ty = fun.0.args[0].ty; // let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; -// +// // call_ndarray_eye_impl( // generator, // context, @@ -1931,7 +2054,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.copy`. // pub fn gen_ndarray_copy<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1942,14 +2065,14 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result, String> { // assert!(obj.is_some()); // assert!(args.is_empty()); -// +// // let llvm_usize = generator.get_size_type(context.ctx); -// +// // let this_ty = obj.as_ref().unwrap().0; // let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); // let this_arg = // obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; -// +// // ndarray_copy_impl( // generator, // context, @@ -1958,7 +2081,7 @@ use nac3parser::ast::{Operator, StrRef}; // ) // .map(NDArrayValue::into) // } -// +// // /// Generates LLVM IR for `ndarray.fill`. // pub fn gen_ndarray_fill<'ctx>( // context: &mut CodeGenContext<'ctx, '_>, @@ -1969,9 +2092,9 @@ use nac3parser::ast::{Operator, StrRef}; // ) -> Result<(), String> { // assert!(obj.is_some()); // assert_eq!(args.len(), 1); -// +// // let llvm_usize = generator.get_size_type(context.ctx); -// +// // let this_ty = obj.as_ref().unwrap().0; // let this_arg = obj // .as_ref() @@ -1982,7 +2105,7 @@ use nac3parser::ast::{Operator, StrRef}; // .into_pointer_value(); // let value_ty = fun.0.args[0].ty; // let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; -// +// // ndarray_fill_flattened( // generator, // context, @@ -1990,9 +2113,9 @@ use nac3parser::ast::{Operator, StrRef}; // |generator, ctx, _| { // let value = if value_arg.is_pointer_value() { // let llvm_i1 = ctx.ctx.bool_type(); -// +// // let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; -// +// // call_memcpy_generic( // ctx, // copy, @@ -2000,18 +2123,18 @@ use nac3parser::ast::{Operator, StrRef}; // value_arg.get_type().size_of().map(Into::into).unwrap(), // llvm_i1.const_zero(), // ); -// +// // copy.into() // } else if value_arg.is_int_value() || value_arg.is_float_value() { // value_arg // } else { // unreachable!() // }; -// +// // Ok(value) // }, // )?; -// +// // Ok(()) // } -// \ No newline at end of file +// diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 4ffd60b1..474962a7 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -23,3 +23,4 @@ pub mod codegen; pub mod symbol_resolver; pub mod toplevel; pub mod typecheck; +pub mod util; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e49748d9..f3c3d607 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,5 +1,6 @@ use std::iter::once; +use crate::util::SizeVariant; use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ @@ -278,19 +279,10 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .collect() } -/// A helper enum used by [`BuiltinBuilder`] -#[derive(Clone, Copy)] -enum SizeVariant { - Bits32, - Bits64, -} - -impl SizeVariant { - fn of_int(self, primitives: &PrimitiveStore) -> Type { - match self { - SizeVariant::Bits32 => primitives.int32, - SizeVariant::Bits64 => primitives.int64, - } +fn size_variant_to_int_type(variant: SizeVariant, primitives: &PrimitiveStore) -> Type { + match variant { + SizeVariant::Bits32 => primitives.int32, + SizeVariant::Bits64 => primitives.int64, } } @@ -1061,7 +1053,7 @@ impl<'a> BuiltinBuilder<'a> { ); // The size variant of the function determines the size of the returned int. - let int_sized = size_variant.of_int(self.primitives); + let int_sized = size_variant_to_int_type(size_variant, self.primitives); let ndarray_int_sized = make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); @@ -1086,7 +1078,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives); Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), ) @@ -1127,7 +1119,7 @@ impl<'a> BuiltinBuilder<'a> { make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); // The size variant of the function determines the type of int returned - let int_sized = size_variant.of_int(self.primitives); + let int_sized = size_variant_to_int_type(size_variant, self.primitives); let ndarray_int_sized = make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); @@ -1150,7 +1142,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives); let func = match kind { Kind::Ceil => builtin_fns::call_ceil, Kind::Floor => builtin_fns::call_floor, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 7dfd8373..864ccdd1 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -34,6 +34,7 @@ pub mod numpy; pub mod type_annotation; use composer::*; use type_annotation::*; + #[cfg(test)] mod test; diff --git a/nac3core/src/util.rs b/nac3core/src/util.rs new file mode 100644 index 00000000..99bc134f --- /dev/null +++ b/nac3core/src/util.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SizeVariant { + Bits32, + Bits64, +}