From 9fa0dfe202abd0a5dacb32a68da89a670738a72a Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 15 Aug 2024 16:16:59 +0800 Subject: [PATCH] WIP: core/ndstrides: hold --- nac3core/irrt/irrt/list.hpp | 41 ++++++++++++++++++++++++- nac3core/irrt/irrt/ndarray/indexing.hpp | 4 +-- nac3core/irrt/irrt/slice.hpp | 6 ++-- nac3core/irrt/irrt_everything.hpp | 1 + nac3core/src/codegen/expr.rs | 4 +-- nac3core/src/codegen/irrt/mod.rs | 19 +++++++++++- nac3core/src/codegen/object/list.rs | 40 ++++++++++++++++++++++-- 7 files changed, 104 insertions(+), 11 deletions(-) diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index c07ad6ac..a3ffbaed 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include namespace { /** @@ -14,4 +15,42 @@ struct List { uint8_t* items; SizeT len; }; -} // namespace \ No newline at end of file + +namespace list { +template +void slice_assign(List* dst, List* src, SizeT itemsize, + UserSlice* user_slice) { + Slice slice = user_slice->indices_checked(dst->len); + + // NOTE: Python does not have this restriction. + if (slice.len() != src->len) { + raise_exception(SizeT, EXN_VALUE_ERROR, + "List destination has {} item(s), but source has {} " + "item(s). The lengths must match.", + slice.len(), src->len, NO_PARAM); + } + + // TODO: Look into how the original implementation was implemented and optimized. + SizeT dst_i = slice.start; + SizeT src_i = 0; + while (src_i < slice.len()) { + __builtin_memcpy(dst->items + dst_i, src->items + src_i, itemsize); + + src_i += 1; + dst_i += slice.step; + } +} +} // namespace list +} // namespace + +extern "C" { +void __nac3_list_slice_assign(List* dst, List* src, + int32_t itemsize, UserSlice* user_slice) { + list::slice_assign(dst, src, itemsize, user_slice); +} + +void __nac3_list_slice_assign64(List* dst, List* src, + int64_t itemsize, UserSlice* user_slice) { + list::slice_assign(dst, src, itemsize, user_slice); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index 94d47c83..e1fb0c2b 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -160,8 +160,8 @@ void index(SizeT num_indexes, const NDIndex* indexes, } else if (index->type == ND_INDEX_TYPE_SLICE) { UserSlice* input = (UserSlice*)index->data; - Slice slice; - input->indices_checked(src_ndarray->shape[src_axis], &slice); + Slice slice = + input->indices_checked(src_ndarray->shape[src_axis]); dst_ndarray->data += (SizeT)slice.start * src_ndarray->strides[src_axis]; diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp index 1ed2d8c4..bb20fe5e 100644 --- a/nac3core/irrt/irrt/slice.hpp +++ b/nac3core/irrt/irrt/slice.hpp @@ -147,7 +147,9 @@ struct UserSlice { * @brief Like `.indices()` but with assertions. */ template - void indices_checked(SliceIndex length, Slice* result) { + Slice indices_checked(SliceIndex length) { + // TODO: Switch to `SizeT length` + if (length < 0) { raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, @@ -159,7 +161,7 @@ struct UserSlice { NO_PARAM, NO_PARAM, NO_PARAM); } - *result = this->indices(length); + return this->indices(length); } }; } // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 63534e26..1cc541ea 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 7eaa0e9b..897b9650 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -632,8 +632,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { loc: Location, ) { let param_model = IntModel(Int64); - let params = - params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap())); + let params = params + .map(|p| p.map(|p| param_model.s_extend_or_bit_cast(generator, self, p, "param"))); let err_msg = self.gen_string(generator, err_msg); self.make_assert_impl(generator, cond, err_name, err_msg, params, loc); diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d92821a7..ad450cad 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -5,7 +5,7 @@ mod test; use super::model::*; use super::object::ndarray::broadcast::ShapeEntry; -use super::object::ndarray::indexing::NDIndex; +use super::object::ndarray::indexing::{NDIndex, UserSlice}; use super::structure::{List, NDArray, NDIter}; use super::{ classes::{ @@ -986,6 +986,23 @@ pub fn setup_irrt_exceptions<'ctx>( } } +pub fn call_nac3_list_slice_assign<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst: Ptr<'ctx, StructModel>>>, + src: Ptr<'ctx, StructModel>>>, + itemsize: Int<'ctx, SizeT>, + user_slice: Ptr<'ctx, StructModel>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_list_slice_assign"); + CallFunction::begin(generator, ctx, &name) + .arg(dst) + .arg(src) + .arg(itemsize) + .arg(user_slice) + .returning_void(); +} + pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/object/list.rs b/nac3core/src/codegen/object/list.rs index 0821e70c..d3ca6a0a 100644 --- a/nac3core/src/codegen/object/list.rs +++ b/nac3core/src/codegen/object/list.rs @@ -1,9 +1,15 @@ use crate::{ - codegen::{model::*, structure::List, CodeGenContext, CodeGenerator}, + codegen::{ + irrt::{call_nac3_list_slice_assign, list_slice_assignment}, + model::*, + object::ndarray::indexing::UserSlice, + structure::List, + CodeGenContext, CodeGenerator, + }, typecheck::typedef::{iter_type_vars, Type, TypeEnum}, }; -use super::AnyObject; +use super::{ndarray::indexing::RustUserSlice, AnyObject}; /// A NAC3 Python List object. #[derive(Debug, Clone, Copy)] @@ -56,7 +62,8 @@ impl<'ctx> ListObject<'ctx> { /// Get the value of this [`ListObject`] as a list with opaque items. /// - /// This function allocates on the stack to create the list. + /// This function allocates on the stack to create the list, but the + /// reference to the `items` are preserved. pub fn get_opaque_list_ptr( &self, generator: &mut G, @@ -84,4 +91,31 @@ impl<'ctx> ListObject<'ctx> { ) -> Int<'ctx, SizeT> { self.instance.get(generator, ctx, |f| f.len, "list_len") } + + pub fn slice_assign_from( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + user_slice: &RustUserSlice<'ctx>, + source: ListObject<'ctx>, + ) { + // Sanity check + assert!(ctx.unifier.unioned(self.item_type, source.item_type)); + + let user_slice_model = StructModel(UserSlice); + let puser_slice = user_slice_model.alloca(generator, ctx, "user_slice"); + user_slice.write_to_user_slice(generator, ctx, puser_slice); + + let itemsize = self.instance.model.get_type(generator, ctx.ctx).size_of(); + + call_nac3_list_slice_assign( + generator, + ctx, + self.get_opaque_list_ptr(generator, ctx), + source.instance.value, + itemsize, + user_slice, + ); + todo!() + } }