forked from M-Labs/nac3
core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list. e.g., rejecting `[[1.0, 2.0], [3.0]]`. Previously it was a todo of `np_array()`.
This commit is contained in:
parent
dd1a19d97f
commit
0ec4d13735
|
@ -2,6 +2,7 @@
|
||||||
#include <irrt/int_types.hpp>
|
#include <irrt/int_types.hpp>
|
||||||
#include <irrt/list.hpp>
|
#include <irrt/list.hpp>
|
||||||
#include <irrt/math_util.hpp>
|
#include <irrt/math_util.hpp>
|
||||||
|
#include <irrt/ndarray/array.hpp>
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/debug.hpp>
|
||||||
|
#include <irrt/exception.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/list.hpp>
|
||||||
|
#include <irrt/ndarray/basic.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
namespace ndarray
|
||||||
|
{
|
||||||
|
namespace array
|
||||||
|
{
|
||||||
|
template <typename SizeT>
|
||||||
|
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT> *list, SizeT ndims, SizeT *shape)
|
||||||
|
{
|
||||||
|
if (shape[axis] == -1)
|
||||||
|
{
|
||||||
|
// Dimension is unspecified. Set it.
|
||||||
|
shape[axis] = list->len;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Dimension is specified. Check.
|
||||||
|
if (shape[axis] != list->len)
|
||||||
|
{
|
||||||
|
// Mismatch, throw an error.
|
||||||
|
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"The requested array has an inhomogenous shape "
|
||||||
|
"after {0} dimension(s).",
|
||||||
|
axis, shape[axis], list->len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (axis + 1 == ndims)
|
||||||
|
{
|
||||||
|
// `list` has type `list[ItemType]`
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// `list` has type `list[list[...]]`
|
||||||
|
List<SizeT> **lists = (List<SizeT> **)(list->items);
|
||||||
|
for (SizeT i = 0; i < list->len; i++)
|
||||||
|
{
|
||||||
|
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Document me
|
||||||
|
template <typename SizeT> void set_and_validate_list_shape(List<SizeT> *list, SizeT ndims, SizeT *shape)
|
||||||
|
{
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++)
|
||||||
|
{
|
||||||
|
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
|
||||||
|
}
|
||||||
|
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void write_list_to_array_helper(SizeT axis, SizeT *index, List<SizeT> *list, NDArray<SizeT> *ndarray)
|
||||||
|
{
|
||||||
|
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
|
||||||
|
if (IRRT_DEBUG_ASSERT_BOOL)
|
||||||
|
{
|
||||||
|
if (!ndarray::basic::is_c_contiguous(ndarray))
|
||||||
|
{
|
||||||
|
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (axis + 1 == ndarray->ndims)
|
||||||
|
{
|
||||||
|
// `list` has type `list[ItemType]`
|
||||||
|
// `ndarray` is contiguous, so we can do this, and this is fast.
|
||||||
|
uint8_t *dst = ndarray->data + (ndarray->itemsize * (*index));
|
||||||
|
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
|
||||||
|
*index += list->len;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// `list` has type `list[list[...]]`
|
||||||
|
List<SizeT> **lists = (List<SizeT> **)(list->items);
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < list->len; i++)
|
||||||
|
{
|
||||||
|
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Document me
|
||||||
|
template <typename SizeT> void write_list_to_array(List<SizeT> *list, NDArray<SizeT> *ndarray)
|
||||||
|
{
|
||||||
|
SizeT index = 0;
|
||||||
|
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
|
||||||
|
}
|
||||||
|
} // namespace array
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
using namespace ndarray::array;
|
||||||
|
|
||||||
|
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t> *list, int32_t ndims, int32_t *shape)
|
||||||
|
{
|
||||||
|
set_and_validate_list_shape(list, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t> *list, int64_t ndims, int64_t *shape)
|
||||||
|
{
|
||||||
|
set_and_validate_list_shape(list, ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_array_write_list_to_array(List<int32_t> *list, NDArray<int32_t> *ndarray)
|
||||||
|
{
|
||||||
|
write_list_to_array(list, ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_array_write_list_to_array64(List<int64_t> *list, NDArray<int64_t> *ndarray)
|
||||||
|
{
|
||||||
|
write_list_to_array(list, ndarray);
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,7 +7,10 @@ use super::{
|
||||||
},
|
},
|
||||||
llvm_intrinsics,
|
llvm_intrinsics,
|
||||||
model::*,
|
model::*,
|
||||||
object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
|
object::{
|
||||||
|
list::List,
|
||||||
|
ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
|
||||||
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||||
|
@ -1136,3 +1139,32 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.arg(dst_ndarray)
|
.arg(dst_ndarray)
|
||||||
.returning_void();
|
.returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
|
||||||
|
ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_array_set_and_validate_list_shape",
|
||||||
|
);
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
|
||||||
|
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_array_write_list_to_array",
|
||||||
|
);
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ use crate::{
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
|
model::*,
|
||||||
object::{
|
object::{
|
||||||
any::AnyObject,
|
any::AnyObject,
|
||||||
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||||
|
@ -21,13 +22,13 @@ use crate::{
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{extract_ndims, PrimDef},
|
helper::extract_ndims,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::Binop,
|
magic_methods::Binop,
|
||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
|
@ -1839,26 +1840,6 @@ pub fn gen_ndarray_array<'ctx>(
|
||||||
assert!(matches!(args.len(), 1..=3));
|
assert!(matches!(args.len(), 1..=3));
|
||||||
|
|
||||||
let obj_ty = fun.0.args[0].ty;
|
let obj_ty = fun.0.args[0].ty;
|
||||||
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
||||||
let mut ty = *params.iter().next().unwrap().1;
|
|
||||||
while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty)
|
|
||||||
{
|
|
||||||
if *obj_id != PrimDef::List.id() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
ty = *params.iter().next().unwrap().1;
|
|
||||||
}
|
|
||||||
ty
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => obj_ty,
|
|
||||||
};
|
|
||||||
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
|
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
|
||||||
|
|
||||||
let copy_arg = if let Some(arg) =
|
let copy_arg = if let Some(arg) =
|
||||||
|
@ -1874,28 +1855,18 @@ pub fn gen_ndarray_array<'ctx>(
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndmin_arg = if let Some(arg) =
|
// The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be
|
||||||
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
// the `ndims` of the function return type.
|
||||||
{
|
let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
let ndmin_ty = fun.0.args[2].ty;
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
|
|
||||||
} else {
|
|
||||||
context.gen_symbol_val(
|
|
||||||
generator,
|
|
||||||
fun.0.args[2].default_value.as_ref().unwrap(),
|
|
||||||
fun.0.args[2].ty,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
call_ndarray_array_impl(
|
let object = AnyObject { value: obj_arg, ty: obj_ty };
|
||||||
generator,
|
// NAC3 booleans are i8.
|
||||||
context,
|
let copy = Int(Bool).truncate(generator, context, copy_arg.into_int_value());
|
||||||
obj_elem_ty,
|
let ndarray = NDArrayObject::make_np_array(generator, context, object, copy)
|
||||||
obj_arg,
|
.atleast_nd(generator, context, ndims);
|
||||||
copy_arg.into_int_value(),
|
|
||||||
ndmin_arg.into_int_value(),
|
Ok(ndarray.instance.value)
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.eye`.
|
/// Generates LLVM IR for `ndarray.eye`.
|
||||||
|
|
|
@ -73,4 +73,35 @@ impl<'ctx> ListObject<'ctx> {
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
) -> Instance<'ctx, Int<SizeT>> {
|
||||||
self.instance.get(generator, ctx, |f| f.len)
|
self.instance.get(generator, ctx, |f| f.len)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the `items` field as an opaque pointer.
|
||||||
|
pub fn get_opaque_items_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
||||||
|
self.instance.get(generator, ctx, |f| f.items).pointer_cast(generator, ctx, Int(Byte))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the value of this [`ListObject`] as a list with opaque items.
|
||||||
|
///
|
||||||
|
/// This function allocates on the stack to create the list, but the
|
||||||
|
/// reference to the `items` are preserved.
|
||||||
|
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>> {
|
||||||
|
let opaque_list = Struct(List { item: Int(Byte) }).alloca(generator, ctx);
|
||||||
|
|
||||||
|
// Copy items pointer
|
||||||
|
let items = self.get_opaque_items_ptr(generator, ctx);
|
||||||
|
opaque_list.set(ctx, |f| f.items, items);
|
||||||
|
|
||||||
|
// Copy len
|
||||||
|
let len = self.instance.get(generator, ctx, |f| f.len);
|
||||||
|
opaque_list.set(ctx, |f| f.len, len);
|
||||||
|
|
||||||
|
opaque_list
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,178 @@
|
||||||
|
use super::NDArrayObject;
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::{
|
||||||
|
call_nac3_ndarray_array_set_and_validate_list_shape,
|
||||||
|
call_nac3_ndarray_array_write_list_to_array,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
object::{any::AnyObject, list::ListObject},
|
||||||
|
stmt::gen_if_else_expr_callback,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn get_list_object_dtype_and_ndims<'ctx>(
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: ListObject<'ctx>,
|
||||||
|
) -> (Type, u64) {
|
||||||
|
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
|
||||||
|
|
||||||
|
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
|
||||||
|
let ndims = ndims + 1; // To count `list` itself.
|
||||||
|
|
||||||
|
(dtype, ndims)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
fn make_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: ListObject<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
|
||||||
|
let list_value = list.get_opaque_list_ptr(generator, ctx);
|
||||||
|
|
||||||
|
// Validate `list` has a consistent shape.
|
||||||
|
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
|
||||||
|
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
|
||||||
|
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int);
|
||||||
|
let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
call_nac3_ndarray_array_set_and_validate_list_shape(
|
||||||
|
generator, ctx, list_value, ndims, shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int);
|
||||||
|
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||||
|
ndarray.create_data(generator, ctx);
|
||||||
|
|
||||||
|
// Copy all contents from the list.
|
||||||
|
call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: ListObject<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
// np_array without copying is only possible `list` is not nested.
|
||||||
|
//
|
||||||
|
// If `list` is `list[T]`, we can create an ndarray with `data` set
|
||||||
|
// to the array pointer of `list`.
|
||||||
|
//
|
||||||
|
// If `list` is `list[list[T]]` or worse, copy.
|
||||||
|
|
||||||
|
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
|
||||||
|
if ndims == 1 {
|
||||||
|
// `list` is not nested
|
||||||
|
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1);
|
||||||
|
|
||||||
|
// Set data
|
||||||
|
let data = list.get_opaque_items_ptr(generator, ctx);
|
||||||
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
|
|
||||||
|
// ndarray->shape[0] = list->len;
|
||||||
|
let shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let list_len = list.instance.get(generator, ctx, |f| f.len);
|
||||||
|
shape.set_index_const(ctx, 0, list_len);
|
||||||
|
|
||||||
|
// Set strides, the `data` is contiguous
|
||||||
|
ndarray.set_strides_contiguous(generator, ctx);
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
} else {
|
||||||
|
// `list` is nested, copy
|
||||||
|
NDArrayObject::make_np_array_list_copy_impl(generator, ctx, list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_np_array_list_impl<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
list: ListObject<'ctx>,
|
||||||
|
copy: Instance<'ctx, Int<Bool>>,
|
||||||
|
) -> Self {
|
||||||
|
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
|
||||||
|
|
||||||
|
let ndarray = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_generator, _ctx| Ok(copy.value),
|
||||||
|
|generator, ctx| {
|
||||||
|
let ndarray = NDArrayObject::make_np_array_list_copy_impl(generator, ctx, list);
|
||||||
|
Ok(Some(ndarray.instance.value))
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
let ndarray =
|
||||||
|
NDArrayObject::make_np_array_list_try_no_copy_impl(generator, ctx, list);
|
||||||
|
Ok(Some(ndarray.instance.value))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
copy: Instance<'ctx, Int<Bool>>,
|
||||||
|
) -> Self {
|
||||||
|
let ndarray_val = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_generator, _ctx| Ok(copy.value),
|
||||||
|
|generator, ctx| {
|
||||||
|
let ndarray = ndarray.make_copy(generator, ctx); // Force copy
|
||||||
|
Ok(Some(ndarray.instance.value))
|
||||||
|
},
|
||||||
|
|_generator, _ctx| {
|
||||||
|
// No need to copy. Return `ndarray` itself.
|
||||||
|
Ok(Some(ndarray.instance.value))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
NDArrayObject::from_value_and_unpacked_types(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_val,
|
||||||
|
ndarray.dtype,
|
||||||
|
ndarray.ndims,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ndarray like `np.array()`.
|
||||||
|
///
|
||||||
|
/// NOTE: The `ndmin` argument is not here. You may want to
|
||||||
|
/// do [`NDArrayObject::atleast_nd`] to achieve that.
|
||||||
|
pub fn make_np_array<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
object: AnyObject<'ctx>,
|
||||||
|
copy: Instance<'ctx, Int<Bool>>,
|
||||||
|
) -> Self {
|
||||||
|
match &*ctx.unifier.get_ty(object.ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let list = ListObject::from_object(generator, ctx, object);
|
||||||
|
NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy)
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||||
|
NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy)
|
||||||
|
}
|
||||||
|
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod array;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
|
@ -74,8 +75,19 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
) -> NDArrayObject<'ctx> {
|
) -> NDArrayObject<'ctx> {
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
|
||||||
|
}
|
||||||
|
|
||||||
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap();
|
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
|
||||||
|
/// `dtype` and `ndims`.
|
||||||
|
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: V,
|
||||||
|
dtype: Type,
|
||||||
|
ndims: u64,
|
||||||
|
) -> Self {
|
||||||
|
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap();
|
||||||
NDArrayObject { dtype, ndims, instance: value }
|
NDArrayObject { dtype, ndims, instance: value }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue