Compare commits
5 Commits
master
...
ndstrides-
Author | SHA1 | Date |
---|---|---|
David Mak | acd976289f | |
David Mak | ebeb4f6dca | |
David Mak | c234684e84 | |
David Mak | 4d9ed9376b | |
David Mak | a481add9af |
10
flake.nix
10
flake.nix
|
@ -107,18 +107,18 @@
|
||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "sipyco";
|
repo = "sipyco";
|
||||||
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
|
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
|
||||||
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
|
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
|
||||||
})
|
})
|
||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "artiq";
|
repo = "artiq";
|
||||||
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
|
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
|
||||||
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
|
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
|
||||||
})
|
})
|
||||||
];
|
];
|
||||||
buildInputs = [
|
buildInputs = [
|
||||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
|
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
|
||||||
pkgs.llvmPackages_14.llvm.out
|
pkgs.llvmPackages_14.llvm.out
|
||||||
];
|
];
|
||||||
phases = [ "buildPhase" "installPhase" ];
|
phases = [ "buildPhase" "installPhase" ];
|
||||||
|
|
|
@ -206,7 +206,7 @@ class Core:
|
||||||
embedding = EmbeddingMap()
|
embedding = EmbeddingMap()
|
||||||
|
|
||||||
if allow_registration:
|
if allow_registration:
|
||||||
compiler.analyze(registered_functions, registered_classes, set())
|
compiler.analyze(registered_functions, registered_classes)
|
||||||
allow_registration = False
|
allow_registration = False
|
||||||
|
|
||||||
if hasattr(method, "__self__"):
|
if hasattr(method, "__self__"):
|
||||||
|
|
|
@ -15,9 +15,10 @@ use pyo3::{
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{
|
codegen::{
|
||||||
expr::{destructure_range, gen_call},
|
expr::{destructure_range, gen_call},
|
||||||
irrt::call_ndarray_calc_size,
|
irrt::ndarray::call_ndarray_calc_size,
|
||||||
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
||||||
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
||||||
|
type_aligned_alloca,
|
||||||
types::{NDArrayType, ProxyType},
|
types::{NDArrayType, ProxyType},
|
||||||
values::{
|
values::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
|
||||||
|
@ -642,27 +643,12 @@ fn format_rpc_ret<'ctx>(
|
||||||
// (4 + 4 * ndims) bytes with 8-byte alignment
|
// (4 + 4 * ndims) bytes with 8-byte alignment
|
||||||
let sizeof_dims =
|
let sizeof_dims =
|
||||||
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||||
let unaligned_buffer_size =
|
let buffer_size =
|
||||||
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
||||||
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
|
||||||
|
|
||||||
let stackptr = call_stacksave(ctx, None);
|
let stackptr = call_stacksave(ctx, None);
|
||||||
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
let buffer =
|
||||||
let buffer = ctx
|
type_aligned_alloca(generator, ctx, llvm_i8_8, buffer_size, Some("rpc.buffer"));
|
||||||
.builder
|
|
||||||
.build_array_alloca(
|
|
||||||
llvm_i8_8,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"rpc.buffer",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let buffer = ctx
|
|
||||||
.builder
|
|
||||||
.build_bit_cast(buffer, llvm_pi8, "")
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap();
|
|
||||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
||||||
|
|
||||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||||
|
@ -735,7 +721,9 @@ fn format_rpc_ret<'ctx>(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
unsafe {
|
||||||
|
ndarray.create_data(generator, ctx, num_elements);
|
||||||
|
}
|
||||||
|
|
||||||
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||||
let ndarray_data_i8 =
|
let ndarray_data_i8 =
|
||||||
|
@ -1376,6 +1364,7 @@ fn polymorphic_print<'ctx>(
|
||||||
let val = NDArrayValue::from_pointer_value(
|
let val = NDArrayValue::from_pointer_value(
|
||||||
value.into_pointer_value(),
|
value.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
|
@ -24,7 +24,7 @@ use parking_lot::{Mutex, RwLock};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
create_exception, exceptions,
|
create_exception, exceptions,
|
||||||
prelude::*,
|
prelude::*,
|
||||||
types::{PyBytes, PyDict, PyNone, PySet},
|
types::{PyBytes, PyDict, PySet},
|
||||||
};
|
};
|
||||||
use tempfile::{self, TempDir};
|
use tempfile::{self, TempDir};
|
||||||
|
|
||||||
|
@ -142,32 +142,14 @@ impl Nac3 {
|
||||||
module: &PyObject,
|
module: &PyObject,
|
||||||
registered_class_ids: &HashSet<u64>,
|
registered_class_ids: &HashSet<u64>,
|
||||||
) -> PyResult<()> {
|
) -> PyResult<()> {
|
||||||
let (module_name, source_file, source) =
|
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
||||||
Python::with_gil(|py| -> PyResult<(String, String, String)> {
|
|
||||||
let module: &PyAny = module.extract(py)?;
|
let module: &PyAny = module.extract(py)?;
|
||||||
let source_file = module.getattr("__file__");
|
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
|
||||||
let (source_file, source) = if let Ok(source_file) = source_file {
|
|
||||||
let source_file = source_file.extract()?;
|
|
||||||
(
|
|
||||||
source_file,
|
|
||||||
fs::read_to_string(&source_file).map_err(|e| {
|
|
||||||
exceptions::PyIOError::new_err(format!(
|
|
||||||
"failed to read input file: {e}"
|
|
||||||
))
|
|
||||||
})?,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// kernels submitted by content have no file
|
|
||||||
// but still can provide source by StringLoader
|
|
||||||
let get_src_fn = module
|
|
||||||
.getattr("__loader__")?
|
|
||||||
.extract::<PyObject>()?
|
|
||||||
.getattr(py, "get_source")?;
|
|
||||||
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
|
|
||||||
};
|
|
||||||
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let source = fs::read_to_string(&source_file).map_err(|e| {
|
||||||
|
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
|
||||||
|
})?;
|
||||||
let parser_result = parse_program(&source, source_file.into())
|
let parser_result = parse_program(&source, source_file.into())
|
||||||
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
||||||
|
|
||||||
|
@ -1090,12 +1072,7 @@ impl Nac3 {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze(
|
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
||||||
&mut self,
|
|
||||||
functions: &PySet,
|
|
||||||
classes: &PySet,
|
|
||||||
content_modules: &PySet,
|
|
||||||
) -> PyResult<()> {
|
|
||||||
let (modules, class_ids) =
|
let (modules, class_ids) =
|
||||||
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||||
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||||
|
@ -1105,22 +1082,14 @@ impl Nac3 {
|
||||||
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||||
|
|
||||||
for function in functions {
|
for function in functions {
|
||||||
let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
|
let module = getmodule_fn.call1((function,))?.extract()?;
|
||||||
if !module.is_none(py) {
|
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
for class in classes {
|
for class in classes {
|
||||||
let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
|
let module = getmodule_fn.call1((class,))?.extract()?;
|
||||||
if !module.is_none(py) {
|
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
}
|
|
||||||
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||||
}
|
}
|
||||||
for module in content_modules {
|
|
||||||
let module: PyObject = module.extract()?;
|
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module.into());
|
|
||||||
}
|
|
||||||
Ok((modules, class_ids))
|
Ok((modules, class_ids))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
|
@ -3,3 +3,5 @@
|
||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
|
#include "irrt/ndarray/basic.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
|
// TODO: To be deleted since NDArray with strides is done.
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
|
|
@ -0,0 +1,342 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/debug.hpp"
|
||||||
|
#include "irrt/exception.hpp"
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace basic {
|
||||||
|
/**
|
||||||
|
* @brief Assert that `shape` does not contain negative dimensions.
|
||||||
|
*
|
||||||
|
* @param ndims Number of dimensions in `shape`
|
||||||
|
* @param shape The shape to check on
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
if (shape[axis] < 0) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"negative dimensions are not allowed; axis {0} "
|
||||||
|
"has dimension {1}",
|
||||||
|
axis, shape[axis], NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void assert_output_shape_same(SizeT ndarray_ndims,
|
||||||
|
const SizeT* ndarray_shape,
|
||||||
|
SizeT output_ndims,
|
||||||
|
const SizeT* output_shape) {
|
||||||
|
if (ndarray_ndims != output_ndims) {
|
||||||
|
// There is no corresponding NumPy error message like this.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
||||||
|
output_ndims, ndarray_ndims, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
||||||
|
if (ndarray_shape[axis] != output_shape[axis]) {
|
||||||
|
// There is no corresponding NumPy error message like this.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"Mismatched dimensions on axis {0}, output has "
|
||||||
|
"dimension {1}, but destination ndarray has dimension {2}.",
|
||||||
|
axis, output_shape[axis], ndarray_shape[axis]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the number of elements of an ndarray given its shape.
|
||||||
|
*
|
||||||
|
* @param ndims Number of dimensions in `shape`
|
||||||
|
* @param shape The shape of the ndarray
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||||
|
SizeT size = 1;
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++)
|
||||||
|
size *= shape[axis];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
||||||
|
*
|
||||||
|
* @param ndims Number of elements in `shape` and `indices`
|
||||||
|
* @param shape The shape of the ndarray
|
||||||
|
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
||||||
|
* @param nth The index of the element of interest.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
SizeT axis = ndims - i - 1;
|
||||||
|
SizeT dim = shape[axis];
|
||||||
|
|
||||||
|
indices[axis] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the number of elements of an `ndarray`
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.size`
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT size(const NDArray<SizeT>* ndarray) {
|
||||||
|
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return of the number of its content of an `ndarray`.
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.nbytes`.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
||||||
|
return size(ndarray) * ndarray->itemsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.__len__`.
|
||||||
|
*
|
||||||
|
* @param dst_length The length.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT len(const NDArray<SizeT>* ndarray) {
|
||||||
|
if (ndarray->ndims != 0) {
|
||||||
|
return ndarray->shape[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// numpy prohibits `__len__` on unsized objects
|
||||||
|
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
||||||
|
*
|
||||||
|
* You may want to see ndarray's rules for C-contiguity:
|
||||||
|
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
||||||
|
// References:
|
||||||
|
// - tinynumpy's implementation:
|
||||||
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
||||||
|
// - ndarray's flags["C_CONTIGUOUS"]:
|
||||||
|
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
||||||
|
// - ndarray's rules for C-contiguity:
|
||||||
|
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||||
|
|
||||||
|
// From
|
||||||
|
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
||||||
|
//
|
||||||
|
// The traditional rule is that for an array to be flagged as C contiguous,
|
||||||
|
// the following must hold:
|
||||||
|
//
|
||||||
|
// strides[-1] == itemsize
|
||||||
|
// strides[i] == shape[i+1] * strides[i + 1]
|
||||||
|
// [...]
|
||||||
|
// According to these rules, a 0- or 1-dimensional array is either both
|
||||||
|
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
||||||
|
// can be C- or F- contiguous, or neither, but not both. Though there
|
||||||
|
// there are exceptions for arrays with zero or one item, in the first
|
||||||
|
// case the check is relaxed up to and including the first dimension
|
||||||
|
// with shape[i] == 0. In the second case `strides == itemsize` will
|
||||||
|
// can be true for all dimensions and both flags are set.
|
||||||
|
|
||||||
|
if (ndarray->ndims == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis_i = ndarray->ndims - i - 1;
|
||||||
|
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
||||||
|
*
|
||||||
|
* This function does no bound check.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
||||||
|
void* element = ndarray->data;
|
||||||
|
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
||||||
|
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
||||||
|
*
|
||||||
|
* This function does no bound check.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
||||||
|
void* element = ndarray->data;
|
||||||
|
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis = ndarray->ndims - i - 1;
|
||||||
|
SizeT dim = ndarray->shape[axis];
|
||||||
|
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
||||||
|
*
|
||||||
|
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis = ndarray->ndims - i - 1;
|
||||||
|
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
||||||
|
stride_product *= ndarray->shape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set an element in `ndarray`.
|
||||||
|
*
|
||||||
|
* @param pelement Pointer to the element in `ndarray` to be set.
|
||||||
|
* @param pvalue Pointer to the value `pelement` will be set to.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
||||||
|
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
||||||
|
*
|
||||||
|
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
||||||
|
// TODO: Handle overlapping.
|
||||||
|
|
||||||
|
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
||||||
|
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||||
|
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||||
|
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace basic
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::basic;
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
||||||
|
assert_shape_no_negative(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
||||||
|
assert_shape_no_negative(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
||||||
|
const int32_t* ndarray_shape,
|
||||||
|
int32_t output_ndims,
|
||||||
|
const int32_t* output_shape) {
|
||||||
|
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
||||||
|
const int64_t* ndarray_shape,
|
||||||
|
int64_t output_ndims,
|
||||||
|
const int64_t* output_shape) {
|
||||||
|
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||||
|
return size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||||
|
return size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||||
|
return nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||||
|
return nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
||||||
|
return len(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
||||||
|
return len(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
||||||
|
return is_c_contiguous(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
||||||
|
return is_c_contiguous(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
||||||
|
return get_nth_pelement(ndarray, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
||||||
|
return get_nth_pelement(ndarray, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
||||||
|
return get_pelement_by_indices(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||||
|
return get_pelement_by_indices(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||||
|
set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||||
|
set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/**
|
||||||
|
* @brief The NDArray object
|
||||||
|
*
|
||||||
|
* Official numpy implementation:
|
||||||
|
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
||||||
|
*
|
||||||
|
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
||||||
|
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
||||||
|
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
||||||
|
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
||||||
|
* `data`. There are also minor differences in the struct layout.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
struct NDArray {
|
||||||
|
/**
|
||||||
|
* @brief The number of bytes of a single element in `data`.
|
||||||
|
*/
|
||||||
|
SizeT itemsize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The number of dimensions of this shape.
|
||||||
|
*/
|
||||||
|
SizeT ndims;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The NDArray shape, with length equal to `ndims`.
|
||||||
|
*
|
||||||
|
* Note that it may contain 0.
|
||||||
|
*/
|
||||||
|
SizeT* shape;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Array strides, with length equal to `ndims`
|
||||||
|
*
|
||||||
|
* The stride values are in units of bytes, not number of elements.
|
||||||
|
*
|
||||||
|
* Note that `strides` can have negative values or contain 0.
|
||||||
|
*/
|
||||||
|
SizeT* strides;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The underlying data this `ndarray` is pointing to.
|
||||||
|
*/
|
||||||
|
void* data;
|
||||||
|
};
|
||||||
|
} // namespace
|
|
@ -74,6 +74,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let arg = NDArrayValue::from_pointer_value(
|
let arg = NDArrayValue::from_pointer_value(
|
||||||
arg.into_pointer_value(),
|
arg.into_pointer_value(),
|
||||||
ctx.get_llvm_type(generator, elem_ty),
|
ctx.get_llvm_type(generator, elem_ty),
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -153,7 +154,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -216,7 +217,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int64,
|
ctx.primitives.int64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -295,7 +296,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint32,
|
ctx.primitives.uint32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -363,7 +364,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint64,
|
ctx.primitives.uint64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -430,7 +431,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -477,7 +478,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -518,7 +519,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -584,7 +585,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| {
|
|generator, ctx, val| {
|
||||||
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
||||||
|
|
||||||
|
@ -639,7 +640,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -690,7 +691,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -921,8 +922,9 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
let n_sz =
|
||||||
|
irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx
|
let n_sz_eqz = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -1135,7 +1137,7 @@ where
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, elem_val| {
|
|generator, ctx, elem_val| {
|
||||||
helper_call_numpy_unary_elementwise(
|
helper_call_numpy_unary_elementwise(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1974,7 +1976,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2016,7 +2018,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2066,7 +2068,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2121,7 +2123,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2163,7 +2165,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2206,7 +2208,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2259,7 +2261,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
let n2_array = numpy::create_ndarray_const_shape(
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2354,7 +2356,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2397,7 +2399,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
|
|
@ -32,7 +32,7 @@ use super::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
},
|
},
|
||||||
types::{ListType, ProxyType},
|
types::{ListType, NDArrayType, ProxyType},
|
||||||
values::{
|
values::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
|
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
|
||||||
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||||
|
@ -43,7 +43,7 @@ use crate::{
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::unpack_ndarray_var_tys,
|
||||||
DefinitionId, TopLevelDef,
|
DefinitionId, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -1570,12 +1570,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let left_val = NDArrayValue::from_pointer_value(
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
left_val.into_pointer_value(),
|
left_val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype1,
|
llvm_ndarray_dtype1,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let right_val = NDArrayValue::from_pointer_value(
|
let right_val = NDArrayValue::from_pointer_value(
|
||||||
right_val.into_pointer_value(),
|
right_val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype2,
|
llvm_ndarray_dtype2,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1631,6 +1633,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let ndarray_val = NDArrayValue::from_pointer_value(
|
let ndarray_val = NDArrayValue::from_pointer_value(
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
llvm_ndarray_dtype,
|
llvm_ndarray_dtype,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1828,6 +1831,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let val = NDArrayValue::from_pointer_value(
|
let val = NDArrayValue::from_pointer_value(
|
||||||
val.into_pointer_value(),
|
val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype,
|
llvm_ndarray_dtype,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1926,6 +1930,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let left_val = NDArrayValue::from_pointer_value(
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
lhs.into_pointer_value(),
|
lhs.into_pointer_value(),
|
||||||
llvm_ndarray_dtype1,
|
llvm_ndarray_dtype1,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -2590,14 +2595,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
_ => 1,
|
_ => 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
|
|
||||||
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
let ndarray_ty =
|
|
||||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||||
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
|
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
|
||||||
|
|
||||||
|
@ -2792,25 +2789,15 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||||
|
|
||||||
|
let num_dims = v.load_ndims(ctx);
|
||||||
|
let num_dims = ctx.builder
|
||||||
|
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
// elements over
|
// elements over
|
||||||
let subscripted_ndarray =
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t)
|
||||||
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
.construct_uninitialized(generator, ctx, num_dims, None);
|
||||||
let ndarray = NDArrayValue::from_pointer_value(
|
|
||||||
subscripted_ndarray,
|
|
||||||
llvm_ndarray_data_t,
|
|
||||||
llvm_usize,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let num_dims = v.load_ndims(ctx);
|
|
||||||
ndarray.store_ndims(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
@ -2842,7 +2829,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
llvm_i1.const_zero(),
|
llvm_i1.const_zero(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = ndarray::call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.shape().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
|
@ -2852,7 +2839,9 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
.builder
|
.builder
|
||||||
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
unsafe {
|
||||||
|
ndarray.create_data(generator, ctx, ndarray_num_elems);
|
||||||
|
}
|
||||||
|
|
||||||
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
|
@ -3547,7 +3536,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None);
|
let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None);
|
||||||
|
|
||||||
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||||
}
|
}
|
||||||
|
@ -3598,3 +3587,90 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
||||||
|
pub fn create_fn_and_call<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
fn_name: &str,
|
||||||
|
ret_type: Option<BasicTypeEnum<'ctx>>,
|
||||||
|
(params, is_var_args): (&[BasicTypeEnum<'ctx>], bool),
|
||||||
|
args: &[BasicValueEnum<'ctx>],
|
||||||
|
call_value_name: Option<&str>,
|
||||||
|
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let intrinsic_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
|
||||||
|
let params = params.iter().copied().map(BasicTypeEnum::into).collect_vec();
|
||||||
|
let fn_type = if let Some(ret_type) = ret_type {
|
||||||
|
ret_type.fn_type(params.as_slice(), is_var_args)
|
||||||
|
} else {
|
||||||
|
ctx.ctx.void_type().fn_type(params.as_slice(), is_var_args)
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.module.add_function(fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(configure) = configure {
|
||||||
|
configure(&intrinsic_fn);
|
||||||
|
}
|
||||||
|
|
||||||
|
let args = args.iter().copied().map(BasicValueEnum::into).collect_vec();
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, args.as_slice(), call_value_name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(Either::left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
||||||
|
///
|
||||||
|
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
|
||||||
|
/// parameters and arguments to be specified as tuples to better indicate the expected type and
|
||||||
|
/// actual value of each parameter-argument pair of the call.
|
||||||
|
pub fn create_and_call_function<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
fn_name: &str,
|
||||||
|
ret_type: Option<BasicTypeEnum<'ctx>>,
|
||||||
|
params: &[(BasicTypeEnum<'ctx>, BasicValueEnum<'ctx>)],
|
||||||
|
value_name: Option<&str>,
|
||||||
|
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let param_tys = params.iter().map(|(ty, _)| ty).copied().map(BasicTypeEnum::into).collect_vec();
|
||||||
|
let arg_values =
|
||||||
|
params.iter().map(|(_, value)| value).copied().map(BasicValueEnum::into).collect_vec();
|
||||||
|
|
||||||
|
create_fn_and_call(
|
||||||
|
ctx,
|
||||||
|
fn_name,
|
||||||
|
ret_type,
|
||||||
|
(param_tys.as_slice(), false),
|
||||||
|
arg_values.as_slice(),
|
||||||
|
value_name,
|
||||||
|
configure,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
||||||
|
///
|
||||||
|
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
|
||||||
|
/// only arguments to be specified and performs inference for the parameter types using
|
||||||
|
/// [`BasicValueEnum::get_type`].
|
||||||
|
pub fn infer_and_call_function<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
fn_name: &str,
|
||||||
|
ret_type: Option<BasicTypeEnum<'ctx>>,
|
||||||
|
args: &[BasicValueEnum<'ctx>],
|
||||||
|
value_name: Option<&str>,
|
||||||
|
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let param_tys = args.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||||
|
|
||||||
|
create_fn_and_call(
|
||||||
|
ctx,
|
||||||
|
fn_name,
|
||||||
|
ret_type,
|
||||||
|
(param_tys.as_slice(), false),
|
||||||
|
args,
|
||||||
|
value_name,
|
||||||
|
configure,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
@ -13,12 +13,11 @@ use super::{CodeGenContext, CodeGenerator};
|
||||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
||||||
pub use list::*;
|
pub use list::*;
|
||||||
pub use math::*;
|
pub use math::*;
|
||||||
pub use ndarray::*;
|
|
||||||
pub use slice::*;
|
pub use slice::*;
|
||||||
|
|
||||||
mod list;
|
mod list;
|
||||||
mod math;
|
mod math;
|
||||||
mod ndarray;
|
pub mod ndarray;
|
||||||
mod slice;
|
mod slice;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
|
@ -60,6 +59,27 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
||||||
|
///
|
||||||
|
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
||||||
|
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &CodeGenContext<'_, '_>,
|
||||||
|
name: &str,
|
||||||
|
) -> String {
|
||||||
|
let mut name = name.to_owned();
|
||||||
|
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
|
32 => {}
|
||||||
|
64 => name.push_str("64"),
|
||||||
|
bit_width => {
|
||||||
|
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name
|
||||||
|
}
|
||||||
|
|
||||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||||
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
||||||
/// NO numeric slice in python.
|
/// NO numeric slice in python.
|
||||||
|
|
|
@ -0,0 +1,258 @@
|
||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
expr::create_and_call_function,
|
||||||
|
irrt::get_usize_dependent_function_name,
|
||||||
|
types::NDArrayType,
|
||||||
|
values::{NDArrayValue, ProxyValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
);
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ndims: IntValue<'ctx>,
|
||||||
|
ndarray_shape: PointerValue<'ctx>,
|
||||||
|
output_ndims: IntValue<'ctx>,
|
||||||
|
output_shape: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_util_assert_output_shape_same",
|
||||||
|
);
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[
|
||||||
|
(llvm_usize.into(), ndarray_ndims.into()),
|
||||||
|
(llvm_pusize.into(), ndarray_shape.into()),
|
||||||
|
(llvm_usize.into(), output_ndims.into()),
|
||||||
|
(llvm_pusize.into(), output_shape.into()),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("size"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("nbytes"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("len"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_i1.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("is_c_contiguous"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_pi8.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
|
||||||
|
Some("pelement"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: PointerValue<'ctx>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name =
|
||||||
|
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_pi8.into()),
|
||||||
|
&[
|
||||||
|
(llvm_ndarray.into(), ndarray.as_base_value().into()),
|
||||||
|
(llvm_pusize.into(), indices.into()),
|
||||||
|
],
|
||||||
|
Some("pelement"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name =
|
||||||
|
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
dst_ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
(llvm_ndarray.into(), src_ndarray.as_base_value().into()),
|
||||||
|
(llvm_ndarray.into(), dst_ndarray.as_base_value().into()),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
|
@ -15,6 +15,9 @@ use crate::codegen::{
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
pub use basic::*;
|
||||||
|
|
||||||
|
mod basic;
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
|
@ -201,6 +201,52 @@ pub fn call_memcpy_generic<'ctx>(
|
||||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `llvm.memcpy` intrinsic.
|
||||||
|
///
|
||||||
|
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
|
||||||
|
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
|
||||||
|
/// Moreover, `len` now refers to the number of elements (rather than bytes) to copy.
|
||||||
|
pub fn call_memcpy_generic_array<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dest: PointerValue<'ctx>,
|
||||||
|
src: PointerValue<'ctx>,
|
||||||
|
len: IntValue<'ctx>,
|
||||||
|
is_volatile: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();
|
||||||
|
|
||||||
|
let dest_elem_t = dest.get_type().get_element_type();
|
||||||
|
let src_elem_t = src.get_type().get_element_type();
|
||||||
|
|
||||||
|
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||||
|
dest
|
||||||
|
} else {
|
||||||
|
ctx.builder
|
||||||
|
.build_bit_cast(dest, llvm_p0i8, "")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||||
|
src
|
||||||
|
} else {
|
||||||
|
ctx.builder
|
||||||
|
.build_bit_cast(src, llvm_p0i8, "")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let len = ctx.builder.build_int_cast(len, llvm_sizeof_expr_t, "").unwrap();
|
||||||
|
let len = ctx.builder.build_int_mul(
|
||||||
|
len,
|
||||||
|
src_elem_t.size_of().unwrap(),
|
||||||
|
""
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||||
|
}
|
||||||
|
|
||||||
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
|
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
|
||||||
///
|
///
|
||||||
/// Arguments:
|
/// Arguments:
|
||||||
|
@ -343,3 +389,25 @@ pub fn call_float_powi<'ctx>(
|
||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`llvm.ctpop`](https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic) intrinsic.
|
||||||
|
pub fn call_int_ctpop<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
src: IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
const FN_NAME: &str = "llvm.ctpop";
|
||||||
|
|
||||||
|
let llvm_src_t = src.get_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||||
|
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
|
@ -1119,3 +1119,106 @@ fn gen_in_range_check<'ctx>(
|
||||||
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
||||||
format!("__{}_va_count", &arg_name).into()
|
format!("__{}_va_count", &arg_name).into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the alignment of the type.
|
||||||
|
///
|
||||||
|
/// This is necessary as `get_alignment` is not implemented as part of [`BasicType`].
|
||||||
|
pub fn get_type_alignment<'ctx>(ty: impl Into<BasicTypeEnum<'ctx>>) -> IntValue<'ctx> {
|
||||||
|
match ty.into() {
|
||||||
|
BasicTypeEnum::ArrayType(ty) => ty.get_alignment(),
|
||||||
|
BasicTypeEnum::FloatType(ty) => ty.get_alignment(),
|
||||||
|
BasicTypeEnum::IntType(ty) => ty.get_alignment(),
|
||||||
|
BasicTypeEnum::PointerType(ty) => ty.get_alignment(),
|
||||||
|
BasicTypeEnum::StructType(ty) => ty.get_alignment(),
|
||||||
|
BasicTypeEnum::VectorType(ty) => ty.get_alignment(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inserts an `alloca` instruction with allocation `size` given in bytes and the alignment of the
|
||||||
|
/// given type.
|
||||||
|
///
|
||||||
|
/// The returned [`PointerValue`] will have a type of `i8*`, a size of at least `size`, and will be
|
||||||
|
/// aligned with the alignment of `align_ty`.
|
||||||
|
pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
align_ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
/// Round `val` up to its modulo `power_of_two`.
|
||||||
|
fn round_up<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
val: IntValue<'ctx>,
|
||||||
|
power_of_two: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
debug_assert_eq!(
|
||||||
|
val.get_type().get_bit_width(),
|
||||||
|
power_of_two.get_type().get_bit_width(),
|
||||||
|
"`val` ({}) and `power_of_two` ({}) must be the same type",
|
||||||
|
val.get_type(),
|
||||||
|
power_of_two.get_type(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let llvm_val_t = val.get_type();
|
||||||
|
|
||||||
|
let max_rem =
|
||||||
|
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "").unwrap();
|
||||||
|
ctx.builder
|
||||||
|
.build_and(
|
||||||
|
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
||||||
|
ctx.builder.build_not(max_rem, "").unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let align_ty = align_ty.into();
|
||||||
|
|
||||||
|
let size = ctx.builder.build_int_cast(size, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
debug_assert_eq!(
|
||||||
|
size.get_type().get_bit_width(),
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
"Expected size_t ({}) for parameter `size` of `aligned_alloca`, got {}",
|
||||||
|
llvm_usize,
|
||||||
|
size.get_type(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let alignment = get_type_alignment(align_ty);
|
||||||
|
let alignment = ctx.builder.build_int_cast(alignment, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
alignment_bitcount,
|
||||||
|
alignment_bitcount.get_type().const_int(1, false),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
"0:AssertionError",
|
||||||
|
"Expected power-of-two alignment for aligned_alloca, got {0}",
|
||||||
|
[Some(alignment), None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let buffer_size = round_up(ctx, size, alignment);
|
||||||
|
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "").unwrap();
|
||||||
|
|
||||||
|
// Just to be absolutely sure, alloca in [i8 x alignment] slices
|
||||||
|
let buffer = ctx.builder.build_array_alloca(align_ty, aligned_slices, "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_bit_cast(buffer, llvm_pi8, name.unwrap_or_default())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
|
@ -3,14 +3,18 @@ use inkwell::{
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
expr::gen_binop_expr_with_values,
|
expr::gen_binop_expr_with_values,
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
calculate_len_for_slice_range,
|
||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
ndarray::{
|
||||||
|
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
|
||||||
|
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
|
@ -27,7 +31,7 @@ use crate::{
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{arraylike_flatten_element_type, PrimDef},
|
helper::{arraylike_flatten_element_type, PrimDef},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::unpack_ndarray_var_tys,
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -37,25 +41,23 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
|
#[deprecated = "Use NDArrayType::construct_uninitialized instead."]
|
||||||
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_ndarray_t = ctx
|
let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
|
||||||
.get_llvm_type(generator, ndarray_ty)
|
.as_base_type()
|
||||||
.into_pointer_type()
|
|
||||||
.get_element_type()
|
.get_element_type()
|
||||||
.into_struct_type();
|
.into_struct_type();
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
|
|
||||||
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
|
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a dynamic shape.
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
|
@ -83,6 +85,7 @@ where
|
||||||
) -> Result<IntValue<'ctx>, String>,
|
) -> Result<IntValue<'ctx>, String>,
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
// Assert that all dimensions are non-negative
|
// Assert that all dimensions are non-negative
|
||||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
@ -122,10 +125,10 @@ where
|
||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
|
||||||
|
|
||||||
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
|
||||||
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
|
||||||
|
.construct_uninitialized(generator, ctx, num_dims, None);
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
@ -189,28 +192,10 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// TODO: Disallow dim_sz > u32_MAX
|
// TODO: Disallow dim_sz > u32_MAX
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
|
||||||
|
|
||||||
for (i, &shape_dim) in shape.iter().enumerate() {
|
|
||||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
|
||||||
let ndarray_dim = unsafe {
|
|
||||||
ndarray.shape().ptr_offset_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(i as u64, true),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype)
|
||||||
|
.construct_dyn_shape(generator, ctx, shape, None);
|
||||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||||
|
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
|
@ -232,7 +217,9 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
&ndarray.shape().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
(None, None),
|
(None, None),
|
||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
unsafe {
|
||||||
|
ndarray.create_data(generator, ctx, ndarray_num_elems);
|
||||||
|
}
|
||||||
|
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
@ -338,20 +325,24 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
let ndims = shape_tuple.get_type().count_fields();
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
|
||||||
let mut shape = Vec::with_capacity(ndims as usize);
|
let shape = (0..ndims)
|
||||||
for dim_i in 0..ndims {
|
.map(|dim_i| {
|
||||||
let dim = ctx
|
ctx.builder
|
||||||
.builder
|
|
||||||
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.map(|v| {
|
||||||
|
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
|
||||||
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_int_value();
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
shape.push(dim);
|
|
||||||
}
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
}
|
}
|
||||||
BasicValueEnum::IntValue(shape_int) => {
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
let shape_int =
|
||||||
|
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||||
}
|
}
|
||||||
|
@ -505,6 +496,7 @@ where
|
||||||
let lhs_val = NDArrayValue::from_pointer_value(
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -517,6 +509,7 @@ where
|
||||||
let rhs_val = NDArrayValue::from_pointer_value(
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -532,6 +525,7 @@ where
|
||||||
let lhs = NDArrayValue::from_pointer_value(
|
let lhs = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -548,6 +542,7 @@ where
|
||||||
let rhs = NDArrayValue::from_pointer_value(
|
let rhs = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -706,7 +701,8 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
{
|
{
|
||||||
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
||||||
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
|
NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
|
||||||
|
.load_ndims(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
||||||
|
@ -856,7 +852,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
||||||
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
|
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let ndarray = gen_if_else_expr_callback(
|
let ndarray = gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
|
@ -932,6 +928,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
return Ok(NDArrayValue::from_pointer_value(
|
return Ok(NDArrayValue::from_pointer_value(
|
||||||
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
));
|
));
|
||||||
|
@ -1269,6 +1266,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = if slices.is_empty() {
|
let ndarray = if slices.is_empty() {
|
||||||
create_ndarray_dyn_shape(
|
create_ndarray_dyn_shape(
|
||||||
|
@ -1282,8 +1280,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
},
|
},
|
||||||
)?
|
)?
|
||||||
} else {
|
} else {
|
||||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
|
||||||
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
.construct_uninitialized(generator, ctx, this.load_ndims(ctx), None);
|
||||||
|
|
||||||
let ndims = this.load_ndims(ctx);
|
let ndims = this.load_ndims(ctx);
|
||||||
ndarray.create_shape(ctx, llvm_usize, ndims);
|
ndarray.create_shape(ctx, llvm_usize, ndims);
|
||||||
|
@ -1465,6 +1463,7 @@ where
|
||||||
let lhs_val = NDArrayValue::from_pointer_value(
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1473,6 +1472,7 @@ where
|
||||||
let rhs_val = NDArrayValue::from_pointer_value(
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1499,6 +1499,7 @@ where
|
||||||
let ndarray = NDArrayValue::from_pointer_value(
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -2061,6 +2062,7 @@ pub fn gen_ndarray_copy<'ctx>(
|
||||||
NDArrayValue::from_pointer_value(
|
NDArrayValue::from_pointer_value(
|
||||||
this_arg.into_pointer_value(),
|
this_arg.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
|
None,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
|
@ -2098,7 +2100,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||||
ndarray_fill_flattened(
|
ndarray_fill_flattened(
|
||||||
generator,
|
generator,
|
||||||
context,
|
context,
|
||||||
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|
||||||
|generator, ctx, _| {
|
|generator, ctx, _| {
|
||||||
let value = if value_arg.is_pointer_value() {
|
let value = if value_arg.is_pointer_value() {
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
@ -2140,7 +2142,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
// Dimensions are reversed in the transposed array
|
||||||
|
@ -2260,7 +2262,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
@ -2548,8 +2550,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
||||||
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
|
||||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::{AsContextRef, Context},
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{IntValue, PointerValue},
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
|
@ -12,9 +12,13 @@ use super::{
|
||||||
structure::{StructField, StructFields},
|
structure::{StructField, StructFields},
|
||||||
ProxyType,
|
ProxyType,
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::{
|
||||||
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
codegen::{
|
||||||
|
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
||||||
{CodeGenContext, CodeGenerator},
|
{CodeGenContext, CodeGenerator},
|
||||||
|
},
|
||||||
|
toplevel::numpy::unpack_ndarray_var_tys,
|
||||||
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Proxy type for a `ndarray` type in LLVM.
|
/// Proxy type for a `ndarray` type in LLVM.
|
||||||
|
@ -27,10 +31,14 @@ pub struct NDArrayType<'ctx> {
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
pub struct NDArrayStructFields<'ctx> {
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
|
#[value_type(usize)]
|
||||||
|
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||||
#[value_type(usize)]
|
#[value_type(usize)]
|
||||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
|
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
}
|
}
|
||||||
|
@ -41,70 +49,45 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
llvm_ty: PointerType<'ctx>,
|
llvm_ty: PointerType<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
|
let ctx = llvm_ty.get_context();
|
||||||
|
|
||||||
|
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
|
||||||
|
|
||||||
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||||
};
|
};
|
||||||
if llvm_ndarray_ty.count_fields() != 3 {
|
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Expected 3 fields in `NDArray`, got {}",
|
"Expected {} fields in `NDArray`, got {}",
|
||||||
|
llvm_expected_ty.len(),
|
||||||
llvm_ndarray_ty.count_fields()
|
llvm_ndarray_ty.count_fields()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
llvm_expected_ty
|
||||||
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
.iter()
|
||||||
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
|
.enumerate()
|
||||||
};
|
.map(|(i, expected_ty)| {
|
||||||
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
(expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
|
||||||
return Err(format!(
|
})
|
||||||
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
|
.try_for_each(|(expected_ty, actual_ty)| {
|
||||||
llvm_usize.get_bit_width(),
|
if expected_ty == actual_ty {
|
||||||
ndarray_ndims_ty.get_bit_width()
|
Ok(())
|
||||||
));
|
} else {
|
||||||
}
|
Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
|
||||||
|
|
||||||
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
|
||||||
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
|
|
||||||
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
|
|
||||||
};
|
|
||||||
let ndarray_dims = ndarray_pdims.get_element_type();
|
|
||||||
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
ndarray_dims.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
|
||||||
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
|
||||||
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
|
||||||
};
|
|
||||||
let ndarray_data = ndarray_pdata.get_element_type();
|
|
||||||
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
if ndarray_data.get_bit_width() != 8 {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
|
||||||
ndarray_data.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this as a member of this Struct
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
fn fields(
|
||||||
|
ctx: impl AsContextRef<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> NDArrayStructFields<'ctx> {
|
||||||
NDArrayStructFields::new(ctx, llvm_usize)
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +95,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn get_fields(
|
pub fn get_fields(
|
||||||
&self,
|
&self,
|
||||||
ctx: &'ctx Context,
|
ctx: impl AsContextRef<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> NDArrayStructFields<'ctx> {
|
) -> NDArrayStructFields<'ctx> {
|
||||||
Self::fields(ctx, llvm_usize)
|
Self::fields(ctx, llvm_usize)
|
||||||
|
@ -120,8 +103,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
pub fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
// struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
|
||||||
//
|
//
|
||||||
// * data : Pointer to an array containing the array data
|
// * data : Pointer to an array containing the array data
|
||||||
// * itemsize: The size of each NDArray elements in bytes
|
// * itemsize: The size of each NDArray elements in bytes
|
||||||
|
@ -147,6 +130,21 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: Type,
|
||||||
|
) -> Self {
|
||||||
|
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
|
||||||
|
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
NDArrayType { ty: Self::llvm_type(ctx.ctx, llvm_usize), dtype: llvm_dtype, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_type(
|
pub fn from_type(
|
||||||
|
@ -165,7 +163,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
self.as_base_type()
|
self.as_base_type()
|
||||||
.get_element_type()
|
.get_element_type()
|
||||||
.into_struct_type()
|
.into_struct_type()
|
||||||
.get_field_type_at_index(0)
|
.get_field_type_at_index(1)
|
||||||
.map(BasicTypeEnum::into_int_type)
|
.map(BasicTypeEnum::into_int_type)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -175,6 +173,119 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||||
self.dtype
|
self.dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
||||||
|
///
|
||||||
|
/// `shape` and `strides` will be automatically allocated onto the stack.
|
||||||
|
///
|
||||||
|
/// The returned ndarray's content will be:
|
||||||
|
/// - `data`: uninitialized.
|
||||||
|
/// - `itemsize`: set to the `sizeof()` of `dtype`.
|
||||||
|
/// - `ndims`: set to the value of `ndims`.
|
||||||
|
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
||||||
|
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
||||||
|
#[must_use]
|
||||||
|
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
// ndims: u64,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let ndarray = self.new_value(generator, ctx, name);
|
||||||
|
|
||||||
|
let itemsize =
|
||||||
|
ctx.builder.build_int_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "").unwrap();
|
||||||
|
ndarray.store_itemsize(ctx, generator, itemsize);
|
||||||
|
|
||||||
|
ndarray.store_ndims(ctx, generator, ndims);
|
||||||
|
|
||||||
|
ndarray.create_shape(ctx, self.llvm_usize, ndims);
|
||||||
|
ndarray.create_strides(ctx, self.llvm_usize, ndims);
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
|
||||||
|
///
|
||||||
|
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||||
|
#[must_use]
|
||||||
|
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: &[u64],
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let ndarray = self.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(shape.len() as u64, false),
|
||||||
|
name,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Write shape
|
||||||
|
let ndarray_shape = ndarray.shape();
|
||||||
|
for (i, dim) in shape.iter().enumerate() {
|
||||||
|
let dim = self.llvm_usize.const_int(*dim, false);
|
||||||
|
unsafe {
|
||||||
|
ndarray_shape.set_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&self.llvm_usize.const_int(i as u64, false),
|
||||||
|
dim,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
|
||||||
|
///
|
||||||
|
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||||
|
#[must_use]
|
||||||
|
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: &[IntValue<'ctx>],
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let ndarray = self.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(shape.len() as u64, false),
|
||||||
|
name,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Write shape
|
||||||
|
let ndarray_shape = ndarray.shape();
|
||||||
|
for (i, dim) in shape.iter().enumerate() {
|
||||||
|
assert_eq!(
|
||||||
|
dim.get_type(),
|
||||||
|
self.llvm_usize,
|
||||||
|
"Expected {} but got {}",
|
||||||
|
self.llvm_usize.print_to_string(),
|
||||||
|
dim.get_type().print_to_string()
|
||||||
|
);
|
||||||
|
unsafe {
|
||||||
|
ndarray_shape.set_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&self.llvm_usize.const_int(i as u64, false),
|
||||||
|
*dim,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||||
|
@ -243,7 +354,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||||
) -> Self::Value {
|
) -> Self::Value {
|
||||||
debug_assert_eq!(value.get_type(), self.as_base_type());
|
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||||
|
|
||||||
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
|
NDArrayValue::from_pointer_value(value, self.dtype, None, self.llvm_usize, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_base_type(&self) -> Self::Base {
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
|
|
@ -145,7 +145,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the value of this field for a given `obj`.
|
/// Sets the value of this field for a given `obj`.
|
||||||
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
|
pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) {
|
||||||
obj.set_field_at_index(self.index, value);
|
obj.set_field_at_index(self.index, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,10 +9,11 @@ use super::{
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt,
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
types::NDArrayType,
|
type_aligned_alloca,
|
||||||
|
types::{structure::StructField, NDArrayType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -21,6 +22,7 @@ use crate::codegen::{
|
||||||
pub struct NDArrayValue<'ctx> {
|
pub struct NDArrayValue<'ctx> {
|
||||||
value: PointerValue<'ctx>,
|
value: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
ndims: Option<u64>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
}
|
}
|
||||||
|
@ -40,12 +42,13 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn from_pointer_value(
|
pub fn from_pointer_value(
|
||||||
ptr: PointerValue<'ctx>,
|
ptr: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
ndims: Option<u64>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
|
@ -75,6 +78,29 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.itemsize
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the size of each element `itemsize` into this instance.
|
||||||
|
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
itemsize: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
self.itemsize(ctx).set(ctx, self.value, itemsize, self.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the size of each element of this `NDArray` as a value.
|
||||||
|
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
self.itemsize(ctx).get(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
/// `getelementptr` on the field.
|
/// `getelementptr` on the field.
|
||||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
@ -105,6 +131,36 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
NDArrayShapeProxy(self)
|
NDArrayShapeProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `stride` array, as if by calling
|
||||||
|
/// `getelementptr` on the field.
|
||||||
|
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.strides
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
|
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.ptr_to_strides(ctx), dims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing the stride with the given `size`.
|
||||||
|
pub fn create_strides(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
self.store_strides(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
NDArrayStridesProxy(self)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
@ -125,21 +181,29 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing data elements with the given element
|
/// Convenience method for creating a new array storing data elements with the given element
|
||||||
/// type `elem_ty` and `size`.
|
/// type `elem_ty` and `size`.
|
||||||
pub fn create_data(
|
///
|
||||||
|
/// The data buffer will be allocated on the stack, and is considered to be owned by this ndarray instance.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `shape` and `itemsize` of the ndarray must be initialized.
|
||||||
|
pub unsafe fn create_data<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
generator: &mut G,
|
||||||
elem_ty: BasicTypeEnum<'ctx>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
size: IntValue<'ctx>,
|
size: IntValue<'ctx>,
|
||||||
) {
|
) {
|
||||||
|
// let itemsize =
|
||||||
|
// ctx.builder.build_int_cast(self.load_itemsize(ctx), size.get_type(), "").unwrap();
|
||||||
let itemsize =
|
let itemsize =
|
||||||
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
|
ctx.builder.build_int_cast(self.dtype.size_of().unwrap(), size.get_type(), "").unwrap();
|
||||||
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
||||||
|
// let nbytes = self.nbytes(generator, ctx);
|
||||||
|
|
||||||
// TODO: What about alignment?
|
let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
|
||||||
self.store_data(
|
self.store_data(ctx, data);
|
||||||
ctx,
|
|
||||||
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
|
// self.set_strides_contiguous(generator, ctx);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||||
|
@ -147,6 +211,133 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
||||||
NDArrayDataProxy(self)
|
NDArrayDataProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Copy shape dimensions from an array.
|
||||||
|
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let num_items = self.load_ndims(ctx);
|
||||||
|
|
||||||
|
call_memcpy_generic_array(
|
||||||
|
ctx,
|
||||||
|
self.shape().base_ptr(ctx, generator),
|
||||||
|
shape,
|
||||||
|
num_items,
|
||||||
|
ctx.ctx.bool_type().const_zero(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy shape dimensions from an ndarray.
|
||||||
|
/// Panics if `ndims` mismatches.
|
||||||
|
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||||
|
let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
|
||||||
|
self.copy_shape_from_array(generator, ctx, src_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy strides dimensions from an array.
|
||||||
|
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
strides: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let num_items = self.load_ndims(ctx);
|
||||||
|
|
||||||
|
call_memcpy_generic_array(
|
||||||
|
ctx,
|
||||||
|
self.strides().base_ptr(ctx, generator),
|
||||||
|
strides,
|
||||||
|
num_items,
|
||||||
|
ctx.ctx.bool_type().const_zero(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy strides dimensions from an ndarray.
|
||||||
|
/// Panics if `ndims` mismatches.
|
||||||
|
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||||
|
let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
|
||||||
|
self.copy_strides_from_array(generator, ctx, src_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the `np.size()` of this ndarray.
|
||||||
|
pub fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the `ndarray.nbytes` of this ndarray.
|
||||||
|
pub fn nbytes<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the `len()` of this ndarray.
|
||||||
|
pub fn len<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this ndarray is C-contiguous.
|
||||||
|
///
|
||||||
|
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
|
||||||
|
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
||||||
|
///
|
||||||
|
/// Update the ndarray's strides to make the ndarray contiguous.
|
||||||
|
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) {
|
||||||
|
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy data from another ndarray.
|
||||||
|
///
|
||||||
|
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
|
||||||
|
/// do not matter. The copying order is determined by how their flattened views look.
|
||||||
|
///
|
||||||
|
/// Panics if the `dtype`s of ndarrays are different.
|
||||||
|
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
|
||||||
|
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
|
@ -168,103 +359,6 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
|
||||||
#[derive(Copy, Clone)]
|
|
||||||
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn element_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
) -> AnyTypeEnum<'ctx> {
|
|
||||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
self.0.load_ndims(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let size = self.size(ctx, generator);
|
|
||||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
in_range,
|
|
||||||
"0:IndexError",
|
|
||||||
"index {0} is out of bounds for axis 0 with size {1}",
|
|
||||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
|
||||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn downcast_to_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: BasicValueEnum<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
value.into_int_value()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn upcast_from_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
value.into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
@ -296,7 +390,12 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
generator: &G,
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
|
irrt::ndarray::call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&self.as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -405,7 +504,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||||
indices_elem_ty.get_bit_width()
|
indices_elem_ty.get_bit_width()
|
||||||
);
|
);
|
||||||
|
|
||||||
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
||||||
let sizeof_elem = ctx
|
let sizeof_elem = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_truncate_or_bit_cast(
|
.build_int_truncate_or_bit_cast(
|
||||||
|
@ -521,3 +620,197 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
||||||
for NDArrayDataProxy<'ctx, '_>
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.0.load_ndims(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
value.into_int_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: IntValue<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
value.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.strides().base_ptr(ctx, generator).get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.strides")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.ptr_to_strides(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.0.load_ndims(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
value.into_int_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: IntValue<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
value.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1759,14 +1759,14 @@ def run() -> int32:
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
test_ndarray_cholesky()
|
# test_ndarray_cholesky()
|
||||||
test_ndarray_qr()
|
# test_ndarray_qr()
|
||||||
test_ndarray_svd()
|
# test_ndarray_svd()
|
||||||
test_ndarray_linalg_inv()
|
# test_ndarray_linalg_inv()
|
||||||
test_ndarray_pinv()
|
# test_ndarray_pinv()
|
||||||
test_ndarray_matrix_power()
|
# test_ndarray_matrix_power()
|
||||||
test_ndarray_det()
|
# test_ndarray_det()
|
||||||
test_ndarray_lu()
|
# test_ndarray_lu()
|
||||||
test_ndarray_schur()
|
# test_ndarray_schur()
|
||||||
test_ndarray_hessenberg()
|
# test_ndarray_hessenberg()
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue