forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: on iter

This commit is contained in:
lyken 2024-08-14 17:30:37 +08:00
parent fd78f7a0e8
commit 15dfb2eaa0
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
11 changed files with 414 additions and 119 deletions

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <irrt/int_defs.hpp> #include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace { namespace {
/** /**
@ -23,15 +24,24 @@ template <typename SizeT>
struct IndicesIter { struct IndicesIter {
SizeT ndims; SizeT ndims;
SizeT* shape; SizeT* shape;
SizeT* indices; SizeT* strides;
SizeT size; // Product of shape SizeT size; // Product of shape
SizeT nth; // The nth (0-based) index of the current indices.
IndicesIter(SizeT ndims, SizeT* shape, SizeT* indices) { SizeT* indices; // The current indices
SizeT nth; // The nth (0-based) index of the current indices.
uint8_t* element; // The current element
// A convenient constructor for internal C++ IRRT.
IndicesIter(SizeT ndims, SizeT* shape, SizeT* strides, SizeT *indices, uint8_t* element) {
this->ndims = ndims; this->ndims = ndims;
this->shape = shape; this->shape = shape;
this->strides = strides;
this->indices = indices; this->indices = indices;
this->element = element;
this->initialize();
}
void initialize() {
reset(); reset();
this->size = 1; this->size = 1;
@ -45,7 +55,7 @@ struct IndicesIter {
nth = 0; nth = 0;
} }
bool ok() { return nth < size; } bool has_next() { return nth < size; }
void next() { void next() {
for (SizeT i = 0; i < ndims; i++) { for (SizeT i = 0; i < ndims; i++) {
@ -60,4 +70,35 @@ struct IndicesIter {
nth++; nth++;
} }
}; };
} // namespace } // namespace
extern "C" {
void __call_nac3_ndarray_indices_iter_initialize(IndicesIter<int32_t>* iter,
int32_t ndims, int32_t* shape,
int32_t* indices) {
iter->initialize(ndims, shape, indices);
}
void __call_nac3_ndarray_indices_iter_initialize64(IndicesIter<int64_t>* iter,
int64_t ndims,
int64_t* shape,
int64_t* indices) {
iter->initialize(ndims, shape, indices);
}
bool __call_nac3_ndarray_indices_iter_has_next(IndicesIter<int32_t>* iter) {
iter->has_next();
}
bool __call_nac3_ndarray_indices_iter_has_next64(IndicesIter<int64_t>* iter) {
iter->has_next();
}
bool __call_nac3_ndarray_indices_iter_next(IndicesIter<int32_t>* iter) {
iter->next();
}
bool __call_nac3_ndarray_indices_iter_next64(IndicesIter<int64_t>* iter) {
iter->next();
}
}

View File

@ -125,7 +125,7 @@ void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
SizeT* mat_indices = indices + u; SizeT* mat_indices = indices + u;
IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices); IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices);
for (; iter.ok(); iter.next()) { for (; iter.has_next(); iter.next()) {
for (SizeT i = 0; i < dst_mat_shape[0]; i++) { for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
for (SizeT j = 0; j < dst_mat_shape[1]; j++) { for (SizeT j = 0; j < dst_mat_shape[1]; j++) {
// `indices` is being reused to index into different ndarrays. // `indices` is being reused to index into different ndarrays.

View File

@ -1238,3 +1238,8 @@ pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
get_sizet_dependent_function_name(generator, ctx, "__nac3_array_write_list_to_array"); get_sizet_dependent_function_name(generator, ctx, "__nac3_array_write_list_to_array");
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void(); CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
} }
pub fn call_nac3_ndarray_indices_iter_initialize<'ctx, G: CodeGenerator + ?Sized>(
) {
}

View File

@ -218,3 +218,5 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
self.gep(ctx, get_field).store(ctx, value); self.gep(ctx, get_field).store(ctx, value);
} }
} }
// TODO: Add an opaque struct type?

View File

@ -20,61 +20,6 @@ use super::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
/// Get the zero value in `np.zeros()` of a `dtype`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into()
} else {
unreachable!()
}
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into()
} else {
unreachable!()
}
}
/// Helper function to create an ndarray with uninitialized values. /// Helper function to create an ndarray with uninitialized values.
/// ///
@ -313,37 +258,9 @@ pub fn gen_ndarray_arange<'ctx>(
let input_ty = fun.0.args[0].ty; let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?.into_int_value(); let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?.into_int_value();
// Define models // Implementation
let sizet_model = IntModel(SizeT); let input_dim = IntModel(SizeT).s_extend_or_bit_cast(generator, ctx, input, "input_dim");
let ndarray = NDArrayObject::from_np_arange(generator, ctx, input_dim);
// Process input
let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim");
// Allocate the resulting ndarray
let ndarray = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
1, // ndims = 1
"arange_ndarray",
);
// `ndarray.shape[0] = input`
let zero = sizet_model.const_0(generator, ctx.ctx);
ndarray
.instance
.get(generator, ctx, |f| f.shape, "shape")
.offset(generator, ctx, zero.value, "dim")
.store(ctx, input);
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
let val =
ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap();
ctx.builder.build_store(pelement, val).unwrap();
Ok(())
})?;
Ok(ndarray.instance.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
@ -386,23 +303,7 @@ pub fn gen_ndarray_shape<'ctx>(
// Process ndarray // Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
Ok(ndarray.make_shape_tuple(generator, ctx).value.as_basic_value_enum())
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims {
let dim = ndarray
.instance
.get(generator, ctx, |f| f.shape, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
}
let shape = TupleObject::create(generator, ctx, objects, "shape");
Ok(shape.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `<ndarray>.strides`. /// Generates LLVM IR for `<ndarray>.strides`.

View File

@ -0,0 +1,196 @@
use inkwell::values::BasicValueEnum;
use super::{scalar::ScalarObject, NDArrayObject};
use crate::{
codegen::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
CodeGenerator,
},
typecheck::typedef::Type,
};
/// Get the zero value in `np.zeros()` of a `dtype`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into()
} else {
unreachable!()
}
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into()
} else {
unreachable!()
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an ndarray like `np.empty`.
pub fn from_np_empty<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// Validate `shape`
// TODO: Should the caller be responsible for this instead?
let ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, ndims);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "full_ndarray");
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
ndarray
}
/// Create an ndarray like `np.full`.
pub fn from_np_full<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
fill_value: ScalarObject<'ctx>,
) -> Self {
// Sanity check on `fill_value`'s dtype.
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value.value);
ndarray
}
/// Create an ndarray like `np.zero`.
pub fn from_np_zero<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_zero_value(generator, ctx, dtype);
let fill_value = ScalarObject { value: fill_value, dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.ones`.
pub fn from_np_ones<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_one_value(generator, ctx, dtype);
let fill_value = ScalarObject { value: fill_value, dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.arange`.
/// The returned ndarray's `dtype` is always `float`
pub fn from_np_arange<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
length: Int<'ctx, SizeT>,
) -> Self {
let ndarray = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
1, // ndims = 1
"arange_ndarray",
);
// `ndarray.shape[0] = length`
ndarray
.instance
.get(generator, ctx, |f| f.shape, "shape")
.offset_const(generator, ctx, 0, "dim")
.store(ctx, length);
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray
.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
let val = ctx
.builder
.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val")
.unwrap();
ctx.builder.build_store(pelement, val).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.eye`.
pub fn from_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
rows: Int<'ctx, SizeT>,
cols: Int<'ctx, SizeT>,
diagonal: Int<'ctx, SizeT>,
) -> Self {
let ndarray = NDArrayObject::alloca_dynamic_shape(
generator,
ctx,
dtype,
&[rows, cols],
"eye_ndarray",
);
ndarray
.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
todo!()
})
.unwrap();
todo!()
}
}

View File

@ -1,5 +1,6 @@
pub mod array; pub mod array;
pub mod broadcast; pub mod broadcast;
pub mod factory;
pub mod functions; pub mod functions;
pub mod indexing; pub mod indexing;
pub mod mapping; pub mod mapping;
@ -38,7 +39,7 @@ use inkwell::{
use scalar::{ScalarObject, ScalarOrNDArray}; use scalar::{ScalarObject, ScalarOrNDArray};
use util::{call_memcpy_model, gen_for_model_auto}; use util::{call_memcpy_model, gen_for_model_auto};
use super::AnyObject; use super::{tuple::TupleObject, AnyObject};
/// A NAC3 Python ndarray object. /// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -229,6 +230,8 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Get the pointer to the n-th (0-based) element. /// Get the pointer to the n-th (0-based) element.
/// ///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
///
/// There is no out-of-bounds check.
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>( pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -245,6 +248,8 @@ impl<'ctx> NDArrayObject<'ctx> {
} }
/// Get the n-th (0-based) scalar. /// Get the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn get_nth<G: CodeGenerator + ?Sized>( pub fn get_nth<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -256,6 +261,23 @@ impl<'ctx> NDArrayObject<'ctx> {
ScalarObject { dtype: self.dtype, value } ScalarObject { dtype: self.dtype, value }
} }
/// Set the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn set_nth<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
scalar: ScalarObject<'ctx>,
) {
// Sanity check on scalar's `dtype`
assert!(ctx.unifier.unioned(scalar.dtype, self.dtype));
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
ctx.builder.build_store(pscalar, scalar.value).unwrap();
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
/// ///
/// Please refer to the IRRT implementation to see its purpose. /// Please refer to the IRRT implementation to see its purpose.
@ -363,6 +385,27 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray ndarray
} }
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[Int<'ctx, SizeT>],
name: &str,
) -> Self {
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, *dim);
}
ndarray
}
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over. /// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
/// ///
/// The new ndarray will own its data and will be C-contiguous. /// The new ndarray will own its data and will be C-contiguous.
@ -517,6 +560,7 @@ impl<'ctx> NDArrayObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>, BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
Ptr<'ctx, IntModel<SizeT>>,
PointerValue<'ctx>, PointerValue<'ctx>,
) -> Result<(), String>, ) -> Result<(), String>,
{ {
@ -735,6 +779,64 @@ impl<'ctx> NDArrayObject<'ctx> {
output_shape, output_shape,
); );
} }
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "shape")
}
/// Create the strides tuple of this ndarray like `np.strides(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "strides")
}
} }
/// TODO: Document me /// TODO: Document me

View File

@ -119,13 +119,14 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
} }
} }
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending /// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending on its [`Type`].
/// on its [`Type`].
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>( pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>, object: AnyObject<'ctx>,
) -> ScalarOrNDArray<'ctx> { ) -> ScalarOrNDArray<'ctx> {
// TODO: Automatically convert a list into an ndarray?
match &*ctx.unifier.get_ty(object.ty) { match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. } TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>

View File

@ -36,7 +36,7 @@ impl<'ctx> TupleObject<'ctx> {
let value = object.value.into_struct_value(); let value = object.value.into_struct_value();
let value_num_fields = value.get_type().count_fields() as usize; let value_num_fields = value.get_type().count_fields() as usize;
assert!( assert!(
value_num_fields != tys.len(), value_num_fields == tys.len(),
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)", "Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
tys.len(), tys.len(),
value_num_fields value_num_fields
@ -87,7 +87,7 @@ impl<'ctx> TupleObject<'ctx> {
/// Get the `i`-th (0-based) object in this tuple. /// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> { pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
assert!(i >= self.len(), "Tuple object with length {} have index {i}", self.len()); assert!(i < self.len(), "Tuple object with length {} have index {i}", self.len());
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap(); let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i]; let ty = self.tys[i];

View File

@ -185,3 +185,30 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
} }
} }
} }
/// An IRRT helper structure used when iterating through an ndarray.
/// Fields of [`IndicesIter`]
pub struct IndicesIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
pub size: F::Out<IntModel<SizeT>>,
pub nth: F::Out<IntModel<SizeT>>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct IndicesIter;
impl<'ctx> StructKind<'ctx> for IndicesIter {
type Fields<F: FieldTraversal<'ctx>> = IndicesIterFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
indices: traversal.add_auto("indices"),
size: traversal.add_auto("size"),
nth: traversal.add_auto("nth"),
}
}
}

View File

@ -2390,17 +2390,37 @@ impl<'a> BuiltinBuilder<'a> {
let x1 = AnyObject { ty: x1_ty, value: x1_val }; let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x1 = NDArrayObject::from_object(generator, ctx, x1); let x1 = NDArrayObject::from_object(generator, ctx, x1);
// The second argument is converted to an ndarray for implementation convenience.
// TODO: Don't do that.
let x2_ty = fun.0.args[1].ty; let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let x2 = ScalarObject { dtype: x2_ty, value: x2_val };
let x2 = x2.as_ndarray(generator, ctx); // The second argument is converted to an ndarray for implementation convenience.
// TODO: Don't do that.
// Create a (1,)-shaped ndarray and put `x2` into it.
let x2_ndarray = NDArrayObject::alloca_constant_shape(
generator,
ctx,
ctx.primitives.float,
&[1],
"x2_ndarray",
);
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
x2_ndarray.set_nth(
generator,
ctx,
zero,
ScalarObject { dtype: x2_ty, value: x2_val },
);
// alloca_constant_shape
// let x2 = x2.as_ndarray(generator, ctx);
let [out] = perform_nalgebra_call( let [out] = perform_nalgebra_call(
generator, generator,
ctx, ctx,
[x1, x2], [x1, x2_ndarray],
[2], [2],
|ctx, [x1, x2], [out]| { |ctx, [x1, x2], [out]| {
call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name())); call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name()));