Implement abstractions over Structs and NDArray #554

Merged
sb10q merged 16 commits from ndstrides-neo into master 2024-11-21 18:16:28 +08:00
8 changed files with 33 additions and 23 deletions
Showing only changes of commit b58c99369e - Show all commits

View File

@ -1,5 +1,4 @@
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"

View File

@ -4,6 +4,6 @@
template<typename SizeT>
struct CSlice {
uint8_t* base;
void* base;
SizeT len;
};

View File

@ -6,7 +6,7 @@
/**
* @brief The int type of ARTIQ exception IDs.
*/
typedef int32_t ExceptionId;
using ExceptionId = int32_t;
/*
* Set of exceptions C++ IRRT can use.
@ -55,14 +55,14 @@ void _raise_exception_helper(ExceptionId id,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(filename)),
.len = static_cast<int32_t>(__builtin_strlen(filename))},
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
.len = static_cast<SizeT>(__builtin_strlen(filename))},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(function)),
.len = static_cast<int32_t>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(msg)),
.len = static_cast<int32_t>(__builtin_strlen(msg))},
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
.len = static_cast<SizeT>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
.len = static_cast<SizeT>(__builtin_strlen(msg))},
};
e.params[0] = param0;
e.params[1] = param1;
@ -70,6 +70,7 @@ void _raise_exception_helper(ExceptionId id,
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
} // namespace
/**
* @brief Raise an exception with location details (location in the IRRT source files).
@ -82,4 +83,3 @@ void _raise_exception_helper(ExceptionId id,
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
} // namespace

View File

@ -8,12 +8,17 @@ using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-type"
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#pragma clang diagnostic pop
#endif
// NDArray indices are always `uint32_t`.

View File

@ -13,12 +13,12 @@ extern "C" {
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t* dest_arr,
void* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t* src_arr,
void* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
@ -29,11 +29,13 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
@ -44,7 +46,7 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
void* tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
@ -53,20 +55,24 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}

View File

@ -90,4 +90,4 @@ double __nac3_j0(double x) {
return j0(x);
}
}
} // namespace

View File

@ -141,4 +141,4 @@ void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}
} // namespace

View File

@ -25,4 +25,4 @@ SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end,
return 0;
}
}
}
} // namespace