Compare commits
No commits in common. "cbcf9678e79cf690f3ad5901c5bbccc6a4685d7e" and "144f0922dbee2c533f55ac9e0a296982d51cc3a2" have entirely different histories.
cbcf9678e7
...
144f0922db
|
@ -735,9 +735,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
||||||
ndarray.create_data(generator, ctx, llvm_elem_ty, num_elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||||
let ndarray_data_i8 =
|
let ndarray_data_i8 =
|
||||||
|
@ -1378,7 +1376,6 @@ fn polymorphic_print<'ctx>(
|
||||||
let val = NDArrayValue::from_pointer_value(
|
let val = NDArrayValue::from_pointer_value(
|
||||||
value.into_pointer_value(),
|
value.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
|
@ -3,5 +3,3 @@
|
||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
|
@ -2,8 +2,6 @@
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
// TODO: To be deleted since NDArray with strides is done.
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
|
|
@ -1,342 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace basic {
|
|
||||||
/**
|
|
||||||
* @brief Assert that `shape` does not contain negative dimensions.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape to check on
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
if (shape[axis] < 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"negative dimensions are not allowed; axis {0} "
|
|
||||||
"has dimension {1}",
|
|
||||||
axis, shape[axis], NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_output_shape_same(SizeT ndarray_ndims,
|
|
||||||
const SizeT* ndarray_shape,
|
|
||||||
SizeT output_ndims,
|
|
||||||
const SizeT* output_shape) {
|
|
||||||
if (ndarray_ndims != output_ndims) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
|
||||||
output_ndims, ndarray_ndims, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
|
||||||
if (ndarray_shape[axis] != output_shape[axis]) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"Mismatched dimensions on axis {0}, output has "
|
|
||||||
"dimension {1}, but destination ndarray has dimension {2}.",
|
|
||||||
axis, output_shape[axis], ndarray_shape[axis]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an ndarray given its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
|
||||||
SizeT size = 1;
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++)
|
|
||||||
size *= shape[axis];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of elements in `shape` and `indices`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
|
||||||
* @param nth The index of the element of interest.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = ndims - i - 1;
|
|
||||||
SizeT dim = shape[axis];
|
|
||||||
|
|
||||||
indices[axis] = nth % dim;
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an `ndarray`
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.size`
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT size(const NDArray<SizeT>* ndarray) {
|
|
||||||
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return of the number of its content of an `ndarray`.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.nbytes`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
|
||||||
return size(ndarray) * ndarray->itemsize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.__len__`.
|
|
||||||
*
|
|
||||||
* @param dst_length The length.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT len(const NDArray<SizeT>* ndarray) {
|
|
||||||
if (ndarray->ndims != 0) {
|
|
||||||
return ndarray->shape[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// numpy prohibits `__len__` on unsized objects
|
|
||||||
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
|
||||||
*
|
|
||||||
* You may want to see ndarray's rules for C-contiguity:
|
|
||||||
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
|
||||||
// References:
|
|
||||||
// - tinynumpy's implementation:
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
|
||||||
// - ndarray's flags["C_CONTIGUOUS"]:
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
|
||||||
// - ndarray's rules for C-contiguity:
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
|
|
||||||
// From
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
|
||||||
//
|
|
||||||
// The traditional rule is that for an array to be flagged as C contiguous,
|
|
||||||
// the following must hold:
|
|
||||||
//
|
|
||||||
// strides[-1] == itemsize
|
|
||||||
// strides[i] == shape[i+1] * strides[i + 1]
|
|
||||||
// [...]
|
|
||||||
// According to these rules, a 0- or 1-dimensional array is either both
|
|
||||||
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
|
||||||
// can be C- or F- contiguous, or neither, but not both. Though there
|
|
||||||
// there are exceptions for arrays with zero or one item, in the first
|
|
||||||
// case the check is relaxed up to and including the first dimension
|
|
||||||
// with shape[i] == 0. In the second case `strides == itemsize` will
|
|
||||||
// can be true for all dimensions and both flags are set.
|
|
||||||
|
|
||||||
if (ndarray->ndims == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis_i = ndarray->ndims - i - 1;
|
|
||||||
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
|
||||||
void* element = ndarray->data;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
|
||||||
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
|
||||||
void* element = ndarray->data;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
SizeT dim = ndarray->shape[axis];
|
|
||||||
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
|
||||||
*
|
|
||||||
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
|
||||||
stride_product *= ndarray->shape[axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set an element in `ndarray`.
|
|
||||||
*
|
|
||||||
* @param pelement Pointer to the element in `ndarray` to be set.
|
|
||||||
* @param pvalue Pointer to the value `pelement` will be set to.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
|
||||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
|
||||||
*
|
|
||||||
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
|
||||||
// TODO: Handle overlapping.
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
|
||||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
|
||||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
|
||||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace basic
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::basic;
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
|
||||||
const int32_t* ndarray_shape,
|
|
||||||
int32_t output_ndims,
|
|
||||||
const int32_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
|
||||||
const int64_t* ndarray_shape,
|
|
||||||
int64_t output_ndims,
|
|
||||||
const int64_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief The NDArray object
|
|
||||||
*
|
|
||||||
* Official numpy implementation:
|
|
||||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
|
||||||
*
|
|
||||||
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
|
||||||
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
|
||||||
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
|
||||||
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
|
||||||
* `data`. There are also minor differences in the struct layout.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct NDArray {
|
|
||||||
/**
|
|
||||||
* @brief The number of bytes of a single element in `data`.
|
|
||||||
*/
|
|
||||||
SizeT itemsize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The number of dimensions of this shape.
|
|
||||||
*/
|
|
||||||
SizeT ndims;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The NDArray shape, with length equal to `ndims`.
|
|
||||||
*
|
|
||||||
* Note that it may contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* shape;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Array strides, with length equal to `ndims`
|
|
||||||
*
|
|
||||||
* The stride values are in units of bytes, not number of elements.
|
|
||||||
*
|
|
||||||
* Note that `strides` can have negative values or contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* strides;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The underlying data this `ndarray` is pointing to.
|
|
||||||
*/
|
|
||||||
void* data;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
|
@ -74,7 +74,6 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let arg = NDArrayValue::from_pointer_value(
|
let arg = NDArrayValue::from_pointer_value(
|
||||||
arg.into_pointer_value(),
|
arg.into_pointer_value(),
|
||||||
ctx.get_llvm_type(generator, elem_ty),
|
ctx.get_llvm_type(generator, elem_ty),
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -154,7 +153,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -217,7 +216,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int64,
|
ctx.primitives.int64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -296,7 +295,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint32,
|
ctx.primitives.uint32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -364,7 +363,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint64,
|
ctx.primitives.uint64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -431,7 +430,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -478,7 +477,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -519,7 +518,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -585,7 +584,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| {
|
|generator, ctx, val| {
|
||||||
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
||||||
|
|
||||||
|
@ -640,7 +639,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -691,7 +690,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -922,7 +921,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, None, llvm_usize, None);
|
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx
|
let n_sz_eqz = ctx
|
||||||
|
@ -1136,7 +1135,7 @@ where
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, elem_val| {
|
|generator, ctx, elem_val| {
|
||||||
helper_call_numpy_unary_elementwise(
|
helper_call_numpy_unary_elementwise(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1975,7 +1974,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2017,7 +2016,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2067,7 +2066,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2122,7 +2121,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
@ -2164,7 +2163,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2207,7 +2206,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2260,7 +2259,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
let n2_array = numpy::create_ndarray_const_shape(
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2355,7 +2354,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
@ -2398,7 +2397,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.shape()
|
n1.shape()
|
||||||
|
|
|
@ -1570,14 +1570,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let left_val = NDArrayValue::from_pointer_value(
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
left_val.into_pointer_value(),
|
left_val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype1,
|
llvm_ndarray_dtype1,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let right_val = NDArrayValue::from_pointer_value(
|
let right_val = NDArrayValue::from_pointer_value(
|
||||||
right_val.into_pointer_value(),
|
right_val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype2,
|
llvm_ndarray_dtype2,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1633,7 +1631,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let ndarray_val = NDArrayValue::from_pointer_value(
|
let ndarray_val = NDArrayValue::from_pointer_value(
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
llvm_ndarray_dtype,
|
llvm_ndarray_dtype,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1831,7 +1828,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let val = NDArrayValue::from_pointer_value(
|
let val = NDArrayValue::from_pointer_value(
|
||||||
val.into_pointer_value(),
|
val.into_pointer_value(),
|
||||||
llvm_ndarray_dtype,
|
llvm_ndarray_dtype,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1930,7 +1926,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let left_val = NDArrayValue::from_pointer_value(
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
lhs.into_pointer_value(),
|
lhs.into_pointer_value(),
|
||||||
llvm_ndarray_dtype1,
|
llvm_ndarray_dtype1,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -2804,7 +2799,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
let ndarray = NDArrayValue::from_pointer_value(
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
subscripted_ndarray,
|
subscripted_ndarray,
|
||||||
llvm_ndarray_data_t,
|
llvm_ndarray_data_t,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -2858,9 +2852,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
.builder
|
.builder
|
||||||
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
unsafe {
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
ndarray.create_data(generator, ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
}
|
|
||||||
|
|
||||||
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
|
@ -3555,7 +3547,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let v = NDArrayValue::from_pointer_value(v, llvm_ty, None, usize, None);
|
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None);
|
||||||
|
|
||||||
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||||
}
|
}
|
||||||
|
@ -3606,97 +3598,3 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
|
||||||
pub fn create_fn_and_call<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
fn_name: &str,
|
|
||||||
ret_type: Option<BasicTypeEnum<'ctx>>,
|
|
||||||
(params, is_var_args): (&[BasicTypeEnum<'ctx>], bool),
|
|
||||||
args: &[BasicValueEnum<'ctx>],
|
|
||||||
call_value_name: Option<&str>,
|
|
||||||
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
|
||||||
) -> Option<BasicValueEnum<'ctx>> {
|
|
||||||
let intrinsic_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
|
|
||||||
let params = params.iter()
|
|
||||||
.copied()
|
|
||||||
.map(BasicTypeEnum::into)
|
|
||||||
.collect_vec();
|
|
||||||
let fn_type = if let Some(ret_type) = ret_type {
|
|
||||||
ret_type.fn_type(params.as_slice(), is_var_args)
|
|
||||||
} else {
|
|
||||||
ctx.ctx.void_type().fn_type(params.as_slice(), is_var_args)
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.module.add_function(fn_name, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(configure) = configure {
|
|
||||||
configure(&intrinsic_fn);
|
|
||||||
}
|
|
||||||
|
|
||||||
let args = args.iter()
|
|
||||||
.copied()
|
|
||||||
.map(BasicValueEnum::into)
|
|
||||||
.collect_vec();
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, args.as_slice(), call_value_name.unwrap_or_default())
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(Either::left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
|
||||||
///
|
|
||||||
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
|
|
||||||
/// parameters and arguments to be specified as tuples to better indicate the expected type and
|
|
||||||
/// actual value of each parameter-argument pair of the call.
|
|
||||||
pub fn create_and_call_function<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
fn_name: &str,
|
|
||||||
ret_type: Option<BasicTypeEnum<'ctx>>,
|
|
||||||
params: &[(BasicTypeEnum<'ctx>, BasicValueEnum<'ctx>)],
|
|
||||||
value_name: Option<&str>,
|
|
||||||
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
|
||||||
) -> Option<BasicValueEnum<'ctx>> {
|
|
||||||
let param_tys = params.iter().map(|(ty, _)| ty).copied().map(BasicTypeEnum::into).collect_vec();
|
|
||||||
let arg_values = params.iter().map(|(_, value)| value).copied().map(BasicValueEnum::into).collect_vec();
|
|
||||||
|
|
||||||
create_fn_and_call(
|
|
||||||
ctx,
|
|
||||||
fn_name,
|
|
||||||
ret_type,
|
|
||||||
(param_tys.as_slice(), false),
|
|
||||||
arg_values.as_slice(),
|
|
||||||
value_name,
|
|
||||||
configure,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a function in the current module and inserts a `call` instruction into the LLVM IR.
|
|
||||||
///
|
|
||||||
/// This is a wrapper around [`create_fn_and_call`] for non-vararg function. This function allows
|
|
||||||
/// only arguments to be specified and performs inference for the parameter types using
|
|
||||||
/// [`BasicValueEnum::get_type`].
|
|
||||||
pub fn infer_and_call_function<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
fn_name: &str,
|
|
||||||
ret_type: Option<BasicTypeEnum<'ctx>>,
|
|
||||||
args: &[BasicValueEnum<'ctx>],
|
|
||||||
value_name: Option<&str>,
|
|
||||||
configure: Option<&dyn Fn(&FunctionValue<'ctx>)>,
|
|
||||||
) -> Option<BasicValueEnum<'ctx>> {
|
|
||||||
let param_tys = args.iter()
|
|
||||||
.map(BasicValueEnum::get_type)
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
create_fn_and_call(
|
|
||||||
ctx,
|
|
||||||
fn_name,
|
|
||||||
ret_type,
|
|
||||||
(param_tys.as_slice(), false),
|
|
||||||
args,
|
|
||||||
value_name,
|
|
||||||
configure,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
|
@ -60,27 +60,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
|
||||||
///
|
|
||||||
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
|
||||||
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'_, '_>,
|
|
||||||
name: &str,
|
|
||||||
) -> String {
|
|
||||||
let mut name = name.to_owned();
|
|
||||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
|
||||||
32 => {}
|
|
||||||
64 => name.push_str("64"),
|
|
||||||
bit_width => {
|
|
||||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name
|
|
||||||
}
|
|
||||||
|
|
||||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||||
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
||||||
/// NO numeric slice in python.
|
/// NO numeric slice in python.
|
||||||
|
|
|
@ -15,9 +15,6 @@ use crate::codegen::{
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
pub use basic::*;
|
|
||||||
|
|
||||||
mod basic;
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
|
@ -1,258 +0,0 @@
|
||||||
use inkwell::{
|
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::create_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
types::NDArrayType,
|
|
||||||
values::{NDArrayValue, ProxyValue},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
shape: PointerValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_util_assert_shape_no_negative",
|
|
||||||
);
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray_ndims: IntValue<'ctx>,
|
|
||||||
ndarray_shape: PointerValue<'ctx>,
|
|
||||||
output_ndims: IntValue<'ctx>,
|
|
||||||
output_shape: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_util_assert_output_shape_same",
|
|
||||||
);
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[
|
|
||||||
(llvm_usize.into(), ndarray_ndims.into()),
|
|
||||||
(llvm_pusize.into(), ndarray_shape.into()),
|
|
||||||
(llvm_usize.into(), output_ndims.into()),
|
|
||||||
(llvm_pusize.into(), output_shape.into()),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("size"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("nbytes"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("len"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_i1.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("is_c_contiguous"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
index: IntValue<'ctx>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_pi8.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
|
|
||||||
Some("pelement"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
indices: PointerValue<'ctx>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name =
|
|
||||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_pi8.into()),
|
|
||||||
&[
|
|
||||||
(llvm_ndarray.into(), ndarray.as_base_value().into()),
|
|
||||||
(llvm_pusize.into(), indices.into()),
|
|
||||||
],
|
|
||||||
Some("pelement"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name =
|
|
||||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
dst_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_ndarray = NDArrayType::llvm_type(ctx.ctx, llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
(llvm_ndarray.into(), src_ndarray.as_base_value().into()),
|
|
||||||
(llvm_ndarray.into(), dst_ndarray.as_base_value().into()),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
|
@ -3,7 +3,6 @@ use inkwell::{
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
|
@ -28,7 +27,7 @@ use crate::{
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{arraylike_flatten_element_type, PrimDef},
|
helper::{arraylike_flatten_element_type, PrimDef},
|
||||||
numpy::unpack_ndarray_var_tys,
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -44,16 +43,19 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_ndarray_t = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty)
|
let llvm_ndarray_t = ctx
|
||||||
.as_base_type()
|
.get_llvm_type(generator, ndarray_ty)
|
||||||
|
.into_pointer_type()
|
||||||
.get_element_type()
|
.get_element_type()
|
||||||
.into_struct_type();
|
.into_struct_type();
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
|
|
||||||
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, None, llvm_usize, None))
|
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a dynamic shape.
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
|
@ -187,10 +189,28 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// TODO: Disallow dim_sz > u32_MAX
|
// TODO: Disallow dim_sz > u32_MAX
|
||||||
}
|
}
|
||||||
|
|
||||||
let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||||
|
|
||||||
|
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
||||||
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
|
for (i, &shape_dim) in shape.iter().enumerate() {
|
||||||
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
|
let ndarray_dim = unsafe {
|
||||||
|
ndarray.shape().ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(i as u64, true),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype)
|
|
||||||
.construct_dyn_shape(generator, ctx, shape, None);
|
|
||||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||||
|
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
|
@ -212,9 +232,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
&ndarray.shape().as_slice_value(ctx, generator),
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
(None, None),
|
(None, None),
|
||||||
);
|
);
|
||||||
unsafe {
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
ndarray.create_data(generator, ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
@ -320,24 +338,20 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
let ndims = shape_tuple.get_type().count_fields();
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
|
||||||
let shape = (0..ndims)
|
let mut shape = Vec::with_capacity(ndims as usize);
|
||||||
.map(|dim_i| {
|
for dim_i in 0..ndims {
|
||||||
ctx.builder
|
let dim = ctx
|
||||||
|
.builder
|
||||||
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.map(|v| {
|
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
|
|
||||||
})
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
})
|
.into_int_value();
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
|
shape.push(dim);
|
||||||
|
}
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
}
|
}
|
||||||
BasicValueEnum::IntValue(shape_int) => {
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
let shape_int =
|
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||||
}
|
}
|
||||||
|
@ -491,7 +505,6 @@ where
|
||||||
let lhs_val = NDArrayValue::from_pointer_value(
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -504,7 +517,6 @@ where
|
||||||
let rhs_val = NDArrayValue::from_pointer_value(
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -520,7 +532,6 @@ where
|
||||||
let lhs = NDArrayValue::from_pointer_value(
|
let lhs = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -537,7 +548,6 @@ where
|
||||||
let rhs = NDArrayValue::from_pointer_value(
|
let rhs = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -696,8 +706,7 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
{
|
{
|
||||||
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
||||||
NDArrayValue::from_pointer_value(v, llvm_elem_ty, None, llvm_usize, None)
|
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
|
||||||
.load_ndims(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
||||||
|
@ -847,7 +856,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
||||||
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
|
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let ndarray = gen_if_else_expr_callback(
|
let ndarray = gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
|
@ -923,7 +932,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
return Ok(NDArrayValue::from_pointer_value(
|
return Ok(NDArrayValue::from_pointer_value(
|
||||||
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
));
|
));
|
||||||
|
@ -1457,7 +1465,6 @@ where
|
||||||
let lhs_val = NDArrayValue::from_pointer_value(
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
lhs_val.into_pointer_value(),
|
lhs_val.into_pointer_value(),
|
||||||
llvm_lhs_elem_ty,
|
llvm_lhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1466,7 +1473,6 @@ where
|
||||||
let rhs_val = NDArrayValue::from_pointer_value(
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
rhs_val.into_pointer_value(),
|
rhs_val.into_pointer_value(),
|
||||||
llvm_rhs_elem_ty,
|
llvm_rhs_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1493,7 +1499,6 @@ where
|
||||||
let ndarray = NDArrayValue::from_pointer_value(
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -2056,7 +2061,6 @@ pub fn gen_ndarray_copy<'ctx>(
|
||||||
NDArrayValue::from_pointer_value(
|
NDArrayValue::from_pointer_value(
|
||||||
this_arg.into_pointer_value(),
|
this_arg.into_pointer_value(),
|
||||||
llvm_elem_ty,
|
llvm_elem_ty,
|
||||||
None,
|
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
|
@ -2094,7 +2098,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||||
ndarray_fill_flattened(
|
ndarray_fill_flattened(
|
||||||
generator,
|
generator,
|
||||||
context,
|
context,
|
||||||
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, None, llvm_usize, None),
|
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, _| {
|
|generator, ctx, _| {
|
||||||
let value = if value_arg.is_pointer_value() {
|
let value = if value_arg.is_pointer_value() {
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
@ -2136,7 +2140,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
// Dimensions are reversed in the transposed array
|
||||||
|
@ -2256,7 +2260,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
@ -2544,8 +2548,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
||||||
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, None, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
||||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, None, llvm_usize, None);
|
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::{AsContextRef, Context},
|
context::Context,
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{IntValue, PointerValue},
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
|
@ -12,13 +12,9 @@ use super::{
|
||||||
structure::{StructField, StructFields},
|
structure::{StructField, StructFields},
|
||||||
ProxyType,
|
ProxyType,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::codegen::{
|
||||||
codegen::{
|
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
||||||
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
|
||||||
{CodeGenContext, CodeGenerator},
|
{CodeGenContext, CodeGenerator},
|
||||||
},
|
|
||||||
toplevel::numpy::unpack_ndarray_var_tys,
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Proxy type for a `ndarray` type in LLVM.
|
/// Proxy type for a `ndarray` type in LLVM.
|
||||||
|
@ -31,14 +27,10 @@ pub struct NDArrayType<'ctx> {
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
pub struct NDArrayStructFields<'ctx> {
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
#[value_type(usize)]
|
|
||||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
#[value_type(usize)]
|
#[value_type(usize)]
|
||||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
|
||||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
}
|
}
|
||||||
|
@ -49,45 +41,70 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
llvm_ty: PointerType<'ctx>,
|
llvm_ty: PointerType<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let ctx = llvm_ty.get_context();
|
|
||||||
|
|
||||||
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
|
|
||||||
|
|
||||||
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||||
};
|
};
|
||||||
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
|
if llvm_ndarray_ty.count_fields() != 3 {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Expected {} fields in `NDArray`, got {}",
|
"Expected 3 fields in `NDArray`, got {}",
|
||||||
llvm_expected_ty.len(),
|
|
||||||
llvm_ndarray_ty.count_fields()
|
llvm_ndarray_ty.count_fields()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm_expected_ty
|
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
||||||
.iter()
|
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
||||||
.enumerate()
|
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
|
||||||
.map(|(i, expected_ty)| {
|
};
|
||||||
(expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
|
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
||||||
})
|
return Err(format!(
|
||||||
.try_for_each(|(expected_ty, actual_ty)| {
|
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
|
||||||
if expected_ty == actual_ty {
|
llvm_usize.get_bit_width(),
|
||||||
Ok(())
|
ndarray_ndims_ty.get_bit_width()
|
||||||
} else {
|
));
|
||||||
Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
|
}
|
||||||
|
|
||||||
|
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
||||||
|
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
|
||||||
|
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
|
||||||
|
};
|
||||||
|
let ndarray_dims = ndarray_pdims.get_element_type();
|
||||||
|
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
ndarray_dims.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
||||||
|
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
||||||
|
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
||||||
|
};
|
||||||
|
let ndarray_data = ndarray_pdata.get_element_type();
|
||||||
|
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if ndarray_data.get_bit_width() != 8 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||||
|
ndarray_data.get_bit_width()
|
||||||
|
));
|
||||||
}
|
}
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this into e.g. StructProxyType
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn fields(
|
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
||||||
ctx: impl AsContextRef<'ctx>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
) -> NDArrayStructFields<'ctx> {
|
|
||||||
NDArrayStructFields::new(ctx, llvm_usize)
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +112,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn get_fields(
|
pub fn get_fields(
|
||||||
&self,
|
&self,
|
||||||
ctx: impl AsContextRef<'ctx>,
|
ctx: &'ctx Context,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> NDArrayStructFields<'ctx> {
|
) -> NDArrayStructFields<'ctx> {
|
||||||
Self::fields(ctx, llvm_usize)
|
Self::fields(ctx, llvm_usize)
|
||||||
|
@ -103,8 +120,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
// struct NDArray { data: i8*, itemsize: size_t, ndims: size_t, shape: size_t*, strides: size_t* }
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||||
//
|
//
|
||||||
// * data : Pointer to an array containing the array data
|
// * data : Pointer to an array containing the array data
|
||||||
// * itemsize: The size of each NDArray elements in bytes
|
// * itemsize: The size of each NDArray elements in bytes
|
||||||
|
@ -130,21 +147,6 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
|
||||||
#[must_use]
|
|
||||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ty: Type,
|
|
||||||
) -> Self {
|
|
||||||
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
|
||||||
|
|
||||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
NDArrayType { ty: Self::llvm_type(ctx.ctx, llvm_usize), dtype: llvm_dtype, llvm_usize }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_type(
|
pub fn from_type(
|
||||||
|
@ -163,7 +165,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
self.as_base_type()
|
self.as_base_type()
|
||||||
.get_element_type()
|
.get_element_type()
|
||||||
.into_struct_type()
|
.into_struct_type()
|
||||||
.get_field_type_at_index(1)
|
.get_field_type_at_index(0)
|
||||||
.map(BasicTypeEnum::into_int_type)
|
.map(BasicTypeEnum::into_int_type)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -173,107 +175,6 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||||
self.dtype
|
self.dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
|
||||||
///
|
|
||||||
/// `shape` and `strides` will be automatically allocated onto the stack.
|
|
||||||
///
|
|
||||||
/// The returned ndarray's content will be:
|
|
||||||
/// - `data`: uninitialized.
|
|
||||||
/// - `itemsize`: set to the `sizeof()` of `dtype`.
|
|
||||||
/// - `ndims`: set to the value of `ndims`.
|
|
||||||
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
#[must_use]
|
|
||||||
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndims: u64,
|
|
||||||
name: Option<&'ctx str>,
|
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
|
||||||
let ndarray = self.new_value(generator, ctx, name);
|
|
||||||
|
|
||||||
let itemsize = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_z_extend_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
|
|
||||||
.unwrap();
|
|
||||||
ndarray.store_itemsize(ctx, generator, itemsize);
|
|
||||||
|
|
||||||
let ndims_val = self.llvm_usize.const_int(ndims, false);
|
|
||||||
ndarray.store_ndims(ctx, generator, ndims_val);
|
|
||||||
|
|
||||||
ndarray.create_shape(ctx, self.llvm_usize, ndims_val);
|
|
||||||
ndarray.create_strides(ctx, self.llvm_usize, ndims_val);
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
|
|
||||||
///
|
|
||||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
|
||||||
#[must_use]
|
|
||||||
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
shape: &[u64],
|
|
||||||
name: Option<&'ctx str>,
|
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
|
||||||
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name);
|
|
||||||
|
|
||||||
// Write shape
|
|
||||||
let ndarray_shape = ndarray.shape();
|
|
||||||
for (i, dim) in shape.iter().enumerate() {
|
|
||||||
let dim = self.llvm_usize.const_int(*dim, false);
|
|
||||||
unsafe {
|
|
||||||
ndarray_shape.set_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&self.llvm_usize.const_int(i as u64, false),
|
|
||||||
dim,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
|
|
||||||
///
|
|
||||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
|
||||||
#[must_use]
|
|
||||||
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
shape: &[IntValue<'ctx>],
|
|
||||||
name: Option<&'ctx str>,
|
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
|
||||||
let ndarray = self.construct_uninitialized(generator, ctx, shape.len() as u64, name);
|
|
||||||
|
|
||||||
// Write shape
|
|
||||||
let ndarray_shape = ndarray.shape();
|
|
||||||
for (i, dim) in shape.iter().enumerate() {
|
|
||||||
assert_eq!(
|
|
||||||
dim.get_type(),
|
|
||||||
self.llvm_usize,
|
|
||||||
"Expected {} but got {}",
|
|
||||||
self.llvm_usize.print_to_string(),
|
|
||||||
dim.get_type().print_to_string()
|
|
||||||
);
|
|
||||||
unsafe {
|
|
||||||
ndarray_shape.set_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&self.llvm_usize.const_int(i as u64, false),
|
|
||||||
*dim,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||||
|
@ -342,7 +243,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||||
) -> Self::Value {
|
) -> Self::Value {
|
||||||
debug_assert_eq!(value.get_type(), self.as_base_type());
|
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||||
|
|
||||||
NDArrayValue::from_pointer_value(value, self.dtype, None, self.llvm_usize, name)
|
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_base_type(&self) -> Self::Base {
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
|
|
@ -21,7 +21,6 @@ use crate::codegen::{
|
||||||
pub struct NDArrayValue<'ctx> {
|
pub struct NDArrayValue<'ctx> {
|
||||||
value: PointerValue<'ctx>,
|
value: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
}
|
}
|
||||||
|
@ -41,13 +40,12 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn from_pointer_value(
|
pub fn from_pointer_value(
|
||||||
ptr: PointerValue<'ctx>,
|
ptr: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
|
@ -77,33 +75,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
|
|
||||||
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
self.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.itemsize
|
|
||||||
.ptr_by_gep(ctx, self.value, self.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the size of each element `itemsize` into this instance.
|
|
||||||
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
|
||||||
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the size of each element of this `NDArray` as a value.
|
|
||||||
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
/// `getelementptr` on the field.
|
/// `getelementptr` on the field.
|
||||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
@ -134,36 +105,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
NDArrayShapeProxy(self)
|
NDArrayShapeProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `stride` array, as if by calling
|
|
||||||
/// `getelementptr` on the field.
|
|
||||||
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
self.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.strides
|
|
||||||
.ptr_by_gep(ctx, self.value, self.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
|
||||||
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
|
||||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing the stride with the given `size`.
|
|
||||||
pub fn create_strides(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
NDArrayStridesProxy(self)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
@ -184,15 +125,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing data elements with the given element
|
/// Convenience method for creating a new array storing data elements with the given element
|
||||||
/// type `elem_ty` and `size`.
|
/// type `elem_ty` and `size`.
|
||||||
///
|
pub fn create_data(
|
||||||
/// The data buffer will be allocated on the stack, and is considered to be owned by this ndarray instance.
|
|
||||||
///
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `shape` and `itemsize` of the ndarray must be initialized.
|
|
||||||
pub unsafe fn create_data<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: BasicTypeEnum<'ctx>,
|
elem_ty: BasicTypeEnum<'ctx>,
|
||||||
size: IntValue<'ctx>,
|
size: IntValue<'ctx>,
|
||||||
|
@ -202,10 +136,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
||||||
|
|
||||||
// TODO: What about alignment?
|
// TODO: What about alignment?
|
||||||
let data = ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap();
|
self.store_data(
|
||||||
self.store_data(ctx, data);
|
ctx,
|
||||||
|
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
|
||||||
// self.set_strides_contiguous(generator, ctx);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||||
|
@ -213,112 +147,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
||||||
NDArrayDataProxy(self)
|
NDArrayDataProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copy shape dimensions from an array.
|
|
||||||
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
shape: PointerValue<'ctx>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy shape dimensions from an ndarray.
|
|
||||||
/// Panics if `ndims` mismatches.
|
|
||||||
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy strides dimensions from an array.
|
|
||||||
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
strides: PointerValue<'ctx>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy strides dimensions from an ndarray.
|
|
||||||
/// Panics if `ndims` mismatches.
|
|
||||||
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `np.size()` of this ndarray.
|
|
||||||
pub fn size<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `ndarray.nbytes` of this ndarray.
|
|
||||||
pub fn nbytes<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `len()` of this ndarray.
|
|
||||||
pub fn len<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this ndarray is C-contiguous.
|
|
||||||
///
|
|
||||||
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
|
|
||||||
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
|
||||||
///
|
|
||||||
/// Update the ndarray's strides to make the ndarray contiguous.
|
|
||||||
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
|
|
||||||
self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy data from another ndarray.
|
|
||||||
///
|
|
||||||
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
|
|
||||||
/// do not matter. The copying order is determined by how their flattened views look.
|
|
||||||
///
|
|
||||||
/// Panics if the `dtype`s of ndarrays are different.
|
|
||||||
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
|
@ -340,6 +168,103 @@ impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.0.load_ndims(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
value.into_int_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: IntValue<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
value.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
@ -596,197 +521,3 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
||||||
for NDArrayDataProxy<'ctx, '_>
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
|
||||||
#[derive(Copy, Clone)]
|
|
||||||
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn element_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
) -> AnyTypeEnum<'ctx> {
|
|
||||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
self.0.load_ndims(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let size = self.size(ctx, generator);
|
|
||||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
in_range,
|
|
||||||
"0:IndexError",
|
|
||||||
"index {0} is out of bounds for axis 0 with size {1}",
|
|
||||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
|
||||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn downcast_to_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: BasicValueEnum<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
value.into_int_value()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
fn upcast_from_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
value.into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
|
||||||
#[derive(Copy, Clone)]
|
|
||||||
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
fn element_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
) -> AnyTypeEnum<'ctx> {
|
|
||||||
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
_: &G,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
self.0.load_ndims(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
idx: &IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let size = self.size(ctx, generator);
|
|
||||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
in_range,
|
|
||||||
"0:IndexError",
|
|
||||||
"index {0} is out of bounds for axis 0 with size {1}",
|
|
||||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
|
||||||
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
fn downcast_to_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: BasicValueEnum<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
value.into_int_value()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
fn upcast_from_type(
|
|
||||||
&self,
|
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
value.into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1759,14 +1759,14 @@ def run() -> int32:
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
# test_ndarray_cholesky()
|
test_ndarray_cholesky()
|
||||||
# test_ndarray_qr()
|
test_ndarray_qr()
|
||||||
# test_ndarray_svd()
|
test_ndarray_svd()
|
||||||
# test_ndarray_linalg_inv()
|
test_ndarray_linalg_inv()
|
||||||
# test_ndarray_pinv()
|
test_ndarray_pinv()
|
||||||
# test_ndarray_matrix_power()
|
test_ndarray_matrix_power()
|
||||||
# test_ndarray_det()
|
test_ndarray_det()
|
||||||
# test_ndarray_lu()
|
test_ndarray_lu()
|
||||||
# test_ndarray_schur()
|
test_ndarray_schur()
|
||||||
# test_ndarray_hessenberg()
|
test_ndarray_hessenberg()
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue