forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: hold

This commit is contained in:
lyken 2024-08-15 16:16:59 +08:00
parent 1c48d54afa
commit 9fa0dfe202
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
7 changed files with 104 additions and 11 deletions

View File

@ -1,6 +1,7 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
namespace {
/**
@ -14,4 +15,42 @@ struct List {
uint8_t* items;
SizeT len;
};
} // namespace
namespace list {
template <typename SizeT>
void slice_assign(List<SizeT>* dst, List<SizeT>* src, SizeT itemsize,
UserSlice* user_slice) {
Slice slice = user_slice->indices_checked<SizeT>(dst->len);
// NOTE: Python does not have this restriction.
if (slice.len() != src->len) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"List destination has {} item(s), but source has {} "
"item(s). The lengths must match.",
slice.len(), src->len, NO_PARAM);
}
// TODO: Look into how the original implementation was implemented and optimized.
SizeT dst_i = slice.start;
SizeT src_i = 0;
while (src_i < slice.len()) {
__builtin_memcpy(dst->items + dst_i, src->items + src_i, itemsize);
src_i += 1;
dst_i += slice.step;
}
}
} // namespace list
} // namespace
extern "C" {
void __nac3_list_slice_assign(List<int32_t>* dst, List<int32_t>* src,
int32_t itemsize, UserSlice* user_slice) {
list::slice_assign(dst, src, itemsize, user_slice);
}
void __nac3_list_slice_assign64(List<int64_t>* dst, List<int64_t>* src,
int64_t itemsize, UserSlice* user_slice) {
list::slice_assign(dst, src, itemsize, user_slice);
}
}

View File

@ -160,8 +160,8 @@ void index(SizeT num_indexes, const NDIndex* indexes,
} else if (index->type == ND_INDEX_TYPE_SLICE) {
UserSlice* input = (UserSlice*)index->data;
Slice slice;
input->indices_checked<SizeT>(src_ndarray->shape[src_axis], &slice);
Slice slice =
input->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data +=
(SizeT)slice.start * src_ndarray->strides[src_axis];

View File

@ -147,7 +147,9 @@ struct UserSlice {
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT>
void indices_checked(SliceIndex length, Slice* result) {
Slice indices_checked(SliceIndex length) {
// TODO: Switch to `SizeT length`
if (length < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"length should not be negative, got {0}", length,
@ -159,7 +161,7 @@ struct UserSlice {
NO_PARAM, NO_PARAM, NO_PARAM);
}
*result = this->indices(length);
return this->indices(length);
}
};
} // namespace

View File

@ -11,6 +11,7 @@
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/list.hpp>
#include <irrt/ndarray/array.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>

View File

@ -632,8 +632,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
loc: Location,
) {
let param_model = IntModel(Int64);
let params =
params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap()));
let params = params
.map(|p| p.map(|p| param_model.s_extend_or_bit_cast(generator, self, p, "param")));
let err_msg = self.gen_string(generator, err_msg);
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);

View File

@ -5,7 +5,7 @@ mod test;
use super::model::*;
use super::object::ndarray::broadcast::ShapeEntry;
use super::object::ndarray::indexing::NDIndex;
use super::object::ndarray::indexing::{NDIndex, UserSlice};
use super::structure::{List, NDArray, NDIter};
use super::{
classes::{
@ -986,6 +986,23 @@ pub fn setup_irrt_exceptions<'ctx>(
}
}
pub fn call_nac3_list_slice_assign<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
src: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
itemsize: Int<'ctx, SizeT>,
user_slice: Ptr<'ctx, StructModel<UserSlice>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_list_slice_assign");
CallFunction::begin(generator, ctx, &name)
.arg(dst)
.arg(src)
.arg(itemsize)
.arg(user_slice)
.returning_void();
}
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -1,9 +1,15 @@
use crate::{
codegen::{model::*, structure::List, CodeGenContext, CodeGenerator},
codegen::{
irrt::{call_nac3_list_slice_assign, list_slice_assignment},
model::*,
object::ndarray::indexing::UserSlice,
structure::List,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
};
use super::AnyObject;
use super::{ndarray::indexing::RustUserSlice, AnyObject};
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
@ -56,7 +62,8 @@ impl<'ctx> ListObject<'ctx> {
/// Get the value of this [`ListObject`] as a list with opaque items.
///
/// This function allocates on the stack to create the list.
/// This function allocates on the stack to create the list, but the
/// reference to the `items` are preserved.
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -84,4 +91,31 @@ impl<'ctx> ListObject<'ctx> {
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.len, "list_len")
}
pub fn slice_assign_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
user_slice: &RustUserSlice<'ctx>,
source: ListObject<'ctx>,
) {
// Sanity check
assert!(ctx.unifier.unioned(self.item_type, source.item_type));
let user_slice_model = StructModel(UserSlice);
let puser_slice = user_slice_model.alloca(generator, ctx, "user_slice");
user_slice.write_to_user_slice(generator, ctx, puser_slice);
let itemsize = self.instance.model.get_type(generator, ctx.ctx).size_of();
call_nac3_list_slice_assign(
generator,
ctx,
self.get_opaque_list_ptr(generator, ctx),
source.instance.value,
itemsize,
user_slice,
);
todo!()
}
}