forked from M-Labs/nac3
WIP: core/ndstrides: on iter
This commit is contained in:
parent
fd78f7a0e8
commit
15dfb2eaa0
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
|
||||
namespace {
|
||||
/**
|
||||
|
@ -23,15 +24,24 @@ template <typename SizeT>
|
|||
struct IndicesIter {
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
SizeT* indices;
|
||||
SizeT* strides;
|
||||
SizeT size; // Product of shape
|
||||
SizeT nth; // The nth (0-based) index of the current indices.
|
||||
|
||||
IndicesIter(SizeT ndims, SizeT* shape, SizeT* indices) {
|
||||
SizeT* indices; // The current indices
|
||||
SizeT nth; // The nth (0-based) index of the current indices.
|
||||
uint8_t* element; // The current element
|
||||
|
||||
// A convenient constructor for internal C++ IRRT.
|
||||
IndicesIter(SizeT ndims, SizeT* shape, SizeT* strides, SizeT *indices, uint8_t* element) {
|
||||
this->ndims = ndims;
|
||||
this->shape = shape;
|
||||
this->strides = strides;
|
||||
this->indices = indices;
|
||||
this->element = element;
|
||||
this->initialize();
|
||||
}
|
||||
|
||||
void initialize() {
|
||||
reset();
|
||||
|
||||
this->size = 1;
|
||||
|
@ -45,7 +55,7 @@ struct IndicesIter {
|
|||
nth = 0;
|
||||
}
|
||||
|
||||
bool ok() { return nth < size; }
|
||||
bool has_next() { return nth < size; }
|
||||
|
||||
void next() {
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
|
@ -60,4 +70,35 @@ struct IndicesIter {
|
|||
nth++;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
void __call_nac3_ndarray_indices_iter_initialize(IndicesIter<int32_t>* iter,
|
||||
int32_t ndims, int32_t* shape,
|
||||
int32_t* indices) {
|
||||
iter->initialize(ndims, shape, indices);
|
||||
}
|
||||
|
||||
void __call_nac3_ndarray_indices_iter_initialize64(IndicesIter<int64_t>* iter,
|
||||
int64_t ndims,
|
||||
int64_t* shape,
|
||||
int64_t* indices) {
|
||||
iter->initialize(ndims, shape, indices);
|
||||
}
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_has_next(IndicesIter<int32_t>* iter) {
|
||||
iter->has_next();
|
||||
}
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_has_next64(IndicesIter<int64_t>* iter) {
|
||||
iter->has_next();
|
||||
}
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_next(IndicesIter<int32_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_next64(IndicesIter<int64_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
}
|
|
@ -125,7 +125,7 @@ void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
|
|||
SizeT* mat_indices = indices + u;
|
||||
IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices);
|
||||
|
||||
for (; iter.ok(); iter.next()) {
|
||||
for (; iter.has_next(); iter.next()) {
|
||||
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
|
||||
for (SizeT j = 0; j < dst_mat_shape[1]; j++) {
|
||||
// `indices` is being reused to index into different ndarrays.
|
||||
|
|
|
@ -1238,3 +1238,8 @@ pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
|||
get_sizet_dependent_function_name(generator, ctx, "__nac3_array_write_list_to_array");
|
||||
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_indices_iter_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
) {
|
||||
}
|
|
@ -218,3 +218,5 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
|
|||
self.gep(ctx, get_field).store(ctx, value);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add an opaque struct type?
|
|
@ -20,61 +20,6 @@ use super::{
|
|||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the one value in `np.ones()` of a `dtype`.
|
||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to create an ndarray with uninitialized values.
|
||||
///
|
||||
|
@ -313,37 +258,9 @@ pub fn gen_ndarray_arange<'ctx>(
|
|||
let input_ty = fun.0.args[0].ty;
|
||||
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?.into_int_value();
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Process input
|
||||
let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim");
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
let ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
1, // ndims = 1
|
||||
"arange_ndarray",
|
||||
);
|
||||
|
||||
// `ndarray.shape[0] = input`
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
ndarray
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.shape, "shape")
|
||||
.offset(generator, ctx, zero.value, "dim")
|
||||
.store(ctx, input);
|
||||
|
||||
// Create data and set elements
|
||||
ndarray.create_data(generator, ctx);
|
||||
ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
||||
let val =
|
||||
ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap();
|
||||
ctx.builder.build_store(pelement, val).unwrap();
|
||||
Ok(())
|
||||
})?;
|
||||
// Implementation
|
||||
let input_dim = IntModel(SizeT).s_extend_or_bit_cast(generator, ctx, input, "input_dim");
|
||||
let ndarray = NDArrayObject::from_np_arange(generator, ctx, input_dim);
|
||||
|
||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
@ -386,23 +303,7 @@ pub fn gen_ndarray_shape<'ctx>(
|
|||
|
||||
// Process ndarray
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||
|
||||
for i in 0..ndarray.ndims {
|
||||
let dim = ndarray
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.shape, "")
|
||||
.offset_const(generator, ctx, i, "")
|
||||
.load(generator, ctx, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
objects
|
||||
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
|
||||
}
|
||||
|
||||
let shape = TupleObject::create(generator, ctx, objects, "shape");
|
||||
Ok(shape.value.as_basic_value_enum())
|
||||
Ok(ndarray.make_shape_tuple(generator, ctx).value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `<ndarray>.strides`.
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
use inkwell::values::BasicValueEnum;
|
||||
|
||||
use super::{scalar::ScalarObject, NDArrayObject};
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
|
||||
CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the one value in `np.ones()` of a `dtype`.
|
||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Create an ndarray like `np.empty`.
|
||||
pub fn from_np_empty<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
ndims: u64,
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
// Validate `shape`
|
||||
// TODO: Should the caller be responsible for this instead?
|
||||
let ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, ndims);
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
|
||||
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "full_ndarray");
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||
ndarray.create_data(generator, ctx);
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.full`.
|
||||
pub fn from_np_full<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
ndims: u64,
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
fill_value: ScalarObject<'ctx>,
|
||||
) -> Self {
|
||||
// Sanity check on `fill_value`'s dtype.
|
||||
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
|
||||
|
||||
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
|
||||
ndarray.fill(generator, ctx, fill_value.value);
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.zero`.
|
||||
pub fn from_np_zero<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
ndims: u64,
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.ones`.
|
||||
pub fn from_np_ones<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
ndims: u64,
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.arange`.
|
||||
/// The returned ndarray's `dtype` is always `float`
|
||||
pub fn from_np_arange<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
length: Int<'ctx, SizeT>,
|
||||
) -> Self {
|
||||
let ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
1, // ndims = 1
|
||||
"arange_ndarray",
|
||||
);
|
||||
|
||||
// `ndarray.shape[0] = length`
|
||||
ndarray
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.shape, "shape")
|
||||
.offset_const(generator, ctx, 0, "dim")
|
||||
.store(ctx, length);
|
||||
|
||||
// Create data and set elements
|
||||
ndarray.create_data(generator, ctx);
|
||||
ndarray
|
||||
.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val")
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pelement, val).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.eye`.
|
||||
pub fn from_np_eye<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
rows: Int<'ctx, SizeT>,
|
||||
cols: Int<'ctx, SizeT>,
|
||||
diagonal: Int<'ctx, SizeT>,
|
||||
) -> Self {
|
||||
let ndarray = NDArrayObject::alloca_dynamic_shape(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
&[rows, cols],
|
||||
"eye_ndarray",
|
||||
);
|
||||
ndarray
|
||||
.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||
// and this loop would not execute.
|
||||
|
||||
todo!()
|
||||
})
|
||||
.unwrap();
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
pub mod array;
|
||||
pub mod broadcast;
|
||||
pub mod factory;
|
||||
pub mod functions;
|
||||
pub mod indexing;
|
||||
pub mod mapping;
|
||||
|
@ -38,7 +39,7 @@ use inkwell::{
|
|||
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||
use util::{call_memcpy_model, gen_for_model_auto};
|
||||
|
||||
use super::AnyObject;
|
||||
use super::{tuple::TupleObject, AnyObject};
|
||||
|
||||
/// A NAC3 Python ndarray object.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
@ -229,6 +230,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
/// Get the pointer to the n-th (0-based) element.
|
||||
///
|
||||
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
|
||||
///
|
||||
/// There is no out-of-bounds check.
|
||||
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
|
@ -245,6 +248,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
}
|
||||
|
||||
/// Get the n-th (0-based) scalar.
|
||||
///
|
||||
/// There is no out-of-bounds check.
|
||||
pub fn get_nth<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
|
@ -256,6 +261,23 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ScalarObject { dtype: self.dtype, value }
|
||||
}
|
||||
|
||||
/// Set the n-th (0-based) scalar.
|
||||
///
|
||||
/// There is no out-of-bounds check.
|
||||
pub fn set_nth<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
nth: Int<'ctx, SizeT>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
) {
|
||||
// Sanity check on scalar's `dtype`
|
||||
assert!(ctx.unifier.unioned(scalar.dtype, self.dtype));
|
||||
|
||||
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
|
||||
ctx.builder.build_store(pscalar, scalar.value).unwrap();
|
||||
}
|
||||
|
||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
||||
///
|
||||
/// Please refer to the IRRT implementation to see its purpose.
|
||||
|
@ -363,6 +385,27 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ndarray
|
||||
}
|
||||
|
||||
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &[Int<'ctx, SizeT>],
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
|
||||
|
||||
// Write shape
|
||||
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, *dim);
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
|
||||
///
|
||||
/// The new ndarray will own its data and will be C-contiguous.
|
||||
|
@ -517,6 +560,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, SizeT>,
|
||||
Ptr<'ctx, IntModel<SizeT>>,
|
||||
PointerValue<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
|
@ -735,6 +779,64 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
output_shape,
|
||||
);
|
||||
}
|
||||
|
||||
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
|
||||
///
|
||||
/// The returned integers in the tuple are in int32.
|
||||
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> TupleObject<'ctx> {
|
||||
// TODO: Don't return a tuple of int32s.
|
||||
|
||||
let mut objects = Vec::with_capacity(self.ndims as usize);
|
||||
|
||||
for i in 0..self.ndims {
|
||||
let dim = self
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.shape, "")
|
||||
.offset_const(generator, ctx, i, "")
|
||||
.load(generator, ctx, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
objects.push(AnyObject {
|
||||
ty: ctx.primitives.int32,
|
||||
value: dim.value.as_basic_value_enum(),
|
||||
});
|
||||
}
|
||||
|
||||
TupleObject::create(generator, ctx, objects, "shape")
|
||||
}
|
||||
|
||||
/// Create the strides tuple of this ndarray like `np.strides(<ndarray>)`.
|
||||
///
|
||||
/// The returned integers in the tuple are in int32.
|
||||
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> TupleObject<'ctx> {
|
||||
// TODO: Don't return a tuple of int32s.
|
||||
|
||||
let mut objects = Vec::with_capacity(self.ndims as usize);
|
||||
|
||||
for i in 0..self.ndims {
|
||||
let dim = self
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.strides, "")
|
||||
.offset_const(generator, ctx, i, "")
|
||||
.load(generator, ctx, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
objects.push(AnyObject {
|
||||
ty: ctx.primitives.int32,
|
||||
value: dim.value.as_basic_value_enum(),
|
||||
});
|
||||
}
|
||||
|
||||
TupleObject::create(generator, ctx, objects, "strides")
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: Document me
|
||||
|
|
|
@ -119,13 +119,14 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending
|
||||
/// on its [`Type`].
|
||||
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending on its [`Type`].
|
||||
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
object: AnyObject<'ctx>,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
// TODO: Automatically convert a list into an ndarray?
|
||||
|
||||
match &*ctx.unifier.get_ty(object.ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
|
|
|
@ -36,7 +36,7 @@ impl<'ctx> TupleObject<'ctx> {
|
|||
let value = object.value.into_struct_value();
|
||||
let value_num_fields = value.get_type().count_fields() as usize;
|
||||
assert!(
|
||||
value_num_fields != tys.len(),
|
||||
value_num_fields == tys.len(),
|
||||
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
||||
tys.len(),
|
||||
value_num_fields
|
||||
|
@ -87,7 +87,7 @@ impl<'ctx> TupleObject<'ctx> {
|
|||
|
||||
/// Get the `i`-th (0-based) object in this tuple.
|
||||
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
|
||||
assert!(i >= self.len(), "Tuple object with length {} have index {i}", self.len());
|
||||
assert!(i < self.len(), "Tuple object with length {} have index {i}", self.len());
|
||||
|
||||
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
|
||||
let ty = self.tys[i];
|
||||
|
|
|
@ -185,3 +185,30 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An IRRT helper structure used when iterating through an ndarray.
|
||||
/// Fields of [`IndicesIter`]
|
||||
pub struct IndicesIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
pub ndims: F::Out<IntModel<SizeT>>,
|
||||
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||
pub size: F::Out<IntModel<SizeT>>,
|
||||
pub nth: F::Out<IntModel<SizeT>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct IndicesIter;
|
||||
|
||||
impl<'ctx> StructKind<'ctx> for IndicesIter {
|
||||
type Fields<F: FieldTraversal<'ctx>> = IndicesIterFields<'ctx, F>;
|
||||
|
||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||
Self::Fields {
|
||||
ndims: traversal.add_auto("ndims"),
|
||||
shape: traversal.add_auto("shape"),
|
||||
indices: traversal.add_auto("indices"),
|
||||
size: traversal.add_auto("size"),
|
||||
nth: traversal.add_auto("nth"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2390,17 +2390,37 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
// The second argument is converted to an ndarray for implementation convenience.
|
||||
// TODO: Don't do that.
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
let x2 = ScalarObject { dtype: x2_ty, value: x2_val };
|
||||
let x2 = x2.as_ndarray(generator, ctx);
|
||||
|
||||
// The second argument is converted to an ndarray for implementation convenience.
|
||||
// TODO: Don't do that.
|
||||
|
||||
// Create a (1,)-shaped ndarray and put `x2` into it.
|
||||
let x2_ndarray = NDArrayObject::alloca_constant_shape(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
&[1],
|
||||
"x2_ndarray",
|
||||
);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
x2_ndarray.set_nth(
|
||||
generator,
|
||||
ctx,
|
||||
zero,
|
||||
ScalarObject { dtype: x2_ty, value: x2_val },
|
||||
);
|
||||
|
||||
// alloca_constant_shape
|
||||
|
||||
// let x2 = x2.as_ndarray(generator, ctx);
|
||||
|
||||
let [out] = perform_nalgebra_call(
|
||||
generator,
|
||||
ctx,
|
||||
[x1, x2],
|
||||
[x1, x2_ndarray],
|
||||
[2],
|
||||
|ctx, [x1, x2], [out]| {
|
||||
call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name()));
|
||||
|
|
Loading…
Reference in New Issue