forked from M-Labs/nac3
WIP: core/ndstrides: hold
This commit is contained in:
parent
1c48d54afa
commit
9fa0dfe202
|
@ -1,6 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/**
|
/**
|
||||||
|
@ -14,4 +15,42 @@ struct List {
|
||||||
uint8_t* items;
|
uint8_t* items;
|
||||||
SizeT len;
|
SizeT len;
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
namespace list {
|
||||||
|
template <typename SizeT>
|
||||||
|
void slice_assign(List<SizeT>* dst, List<SizeT>* src, SizeT itemsize,
|
||||||
|
UserSlice* user_slice) {
|
||||||
|
Slice slice = user_slice->indices_checked<SizeT>(dst->len);
|
||||||
|
|
||||||
|
// NOTE: Python does not have this restriction.
|
||||||
|
if (slice.len() != src->len) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"List destination has {} item(s), but source has {} "
|
||||||
|
"item(s). The lengths must match.",
|
||||||
|
slice.len(), src->len, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Look into how the original implementation was implemented and optimized.
|
||||||
|
SizeT dst_i = slice.start;
|
||||||
|
SizeT src_i = 0;
|
||||||
|
while (src_i < slice.len()) {
|
||||||
|
__builtin_memcpy(dst->items + dst_i, src->items + src_i, itemsize);
|
||||||
|
|
||||||
|
src_i += 1;
|
||||||
|
dst_i += slice.step;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace list
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
void __nac3_list_slice_assign(List<int32_t>* dst, List<int32_t>* src,
|
||||||
|
int32_t itemsize, UserSlice* user_slice) {
|
||||||
|
list::slice_assign(dst, src, itemsize, user_slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_list_slice_assign64(List<int64_t>* dst, List<int64_t>* src,
|
||||||
|
int64_t itemsize, UserSlice* user_slice) {
|
||||||
|
list::slice_assign(dst, src, itemsize, user_slice);
|
||||||
|
}
|
||||||
|
}
|
|
@ -160,8 +160,8 @@ void index(SizeT num_indexes, const NDIndex* indexes,
|
||||||
} else if (index->type == ND_INDEX_TYPE_SLICE) {
|
} else if (index->type == ND_INDEX_TYPE_SLICE) {
|
||||||
UserSlice* input = (UserSlice*)index->data;
|
UserSlice* input = (UserSlice*)index->data;
|
||||||
|
|
||||||
Slice slice;
|
Slice slice =
|
||||||
input->indices_checked<SizeT>(src_ndarray->shape[src_axis], &slice);
|
input->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
||||||
|
|
||||||
dst_ndarray->data +=
|
dst_ndarray->data +=
|
||||||
(SizeT)slice.start * src_ndarray->strides[src_axis];
|
(SizeT)slice.start * src_ndarray->strides[src_axis];
|
||||||
|
|
|
@ -147,7 +147,9 @@ struct UserSlice {
|
||||||
* @brief Like `.indices()` but with assertions.
|
* @brief Like `.indices()` but with assertions.
|
||||||
*/
|
*/
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
void indices_checked(SliceIndex length, Slice* result) {
|
Slice indices_checked(SliceIndex length) {
|
||||||
|
// TODO: Switch to `SizeT length`
|
||||||
|
|
||||||
if (length < 0) {
|
if (length < 0) {
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
"length should not be negative, got {0}", length,
|
"length should not be negative, got {0}", length,
|
||||||
|
@ -159,7 +161,7 @@ struct UserSlice {
|
||||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
}
|
}
|
||||||
|
|
||||||
*result = this->indices(length);
|
return this->indices(length);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
|
@ -11,6 +11,7 @@
|
||||||
#include <irrt/debug.hpp>
|
#include <irrt/debug.hpp>
|
||||||
#include <irrt/exception.hpp>
|
#include <irrt/exception.hpp>
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/list.hpp>
|
||||||
#include <irrt/ndarray/array.hpp>
|
#include <irrt/ndarray/array.hpp>
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/broadcast.hpp>
|
#include <irrt/ndarray/broadcast.hpp>
|
||||||
|
|
|
@ -632,8 +632,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
loc: Location,
|
loc: Location,
|
||||||
) {
|
) {
|
||||||
let param_model = IntModel(Int64);
|
let param_model = IntModel(Int64);
|
||||||
let params =
|
let params = params
|
||||||
params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap()));
|
.map(|p| p.map(|p| param_model.s_extend_or_bit_cast(generator, self, p, "param")));
|
||||||
|
|
||||||
let err_msg = self.gen_string(generator, err_msg);
|
let err_msg = self.gen_string(generator, err_msg);
|
||||||
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
|
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
|
||||||
|
|
|
@ -5,7 +5,7 @@ mod test;
|
||||||
|
|
||||||
use super::model::*;
|
use super::model::*;
|
||||||
use super::object::ndarray::broadcast::ShapeEntry;
|
use super::object::ndarray::broadcast::ShapeEntry;
|
||||||
use super::object::ndarray::indexing::NDIndex;
|
use super::object::ndarray::indexing::{NDIndex, UserSlice};
|
||||||
use super::structure::{List, NDArray, NDIter};
|
use super::structure::{List, NDArray, NDIter};
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
|
@ -986,6 +986,23 @@ pub fn setup_irrt_exceptions<'ctx>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_list_slice_assign<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dst: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
|
||||||
|
src: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
|
||||||
|
itemsize: Int<'ctx, SizeT>,
|
||||||
|
user_slice: Ptr<'ctx, StructModel<UserSlice>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_list_slice_assign");
|
||||||
|
CallFunction::begin(generator, ctx, &name)
|
||||||
|
.arg(dst)
|
||||||
|
.arg(src)
|
||||||
|
.arg(itemsize)
|
||||||
|
.arg(user_slice)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
|
@ -1,9 +1,15 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{model::*, structure::List, CodeGenContext, CodeGenerator},
|
codegen::{
|
||||||
|
irrt::{call_nac3_list_slice_assign, list_slice_assignment},
|
||||||
|
model::*,
|
||||||
|
object::ndarray::indexing::UserSlice,
|
||||||
|
structure::List,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::AnyObject;
|
use super::{ndarray::indexing::RustUserSlice, AnyObject};
|
||||||
|
|
||||||
/// A NAC3 Python List object.
|
/// A NAC3 Python List object.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
@ -56,7 +62,8 @@ impl<'ctx> ListObject<'ctx> {
|
||||||
|
|
||||||
/// Get the value of this [`ListObject`] as a list with opaque items.
|
/// Get the value of this [`ListObject`] as a list with opaque items.
|
||||||
///
|
///
|
||||||
/// This function allocates on the stack to create the list.
|
/// This function allocates on the stack to create the list, but the
|
||||||
|
/// reference to the `items` are preserved.
|
||||||
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
|
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -84,4 +91,31 @@ impl<'ctx> ListObject<'ctx> {
|
||||||
) -> Int<'ctx, SizeT> {
|
) -> Int<'ctx, SizeT> {
|
||||||
self.instance.get(generator, ctx, |f| f.len, "list_len")
|
self.instance.get(generator, ctx, |f| f.len, "list_len")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn slice_assign_from<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
user_slice: &RustUserSlice<'ctx>,
|
||||||
|
source: ListObject<'ctx>,
|
||||||
|
) {
|
||||||
|
// Sanity check
|
||||||
|
assert!(ctx.unifier.unioned(self.item_type, source.item_type));
|
||||||
|
|
||||||
|
let user_slice_model = StructModel(UserSlice);
|
||||||
|
let puser_slice = user_slice_model.alloca(generator, ctx, "user_slice");
|
||||||
|
user_slice.write_to_user_slice(generator, ctx, puser_slice);
|
||||||
|
|
||||||
|
let itemsize = self.instance.model.get_type(generator, ctx.ctx).size_of();
|
||||||
|
|
||||||
|
call_nac3_list_slice_assign(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
self.get_opaque_list_ptr(generator, ctx),
|
||||||
|
source.instance.value,
|
||||||
|
itemsize,
|
||||||
|
user_slice,
|
||||||
|
);
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue