core/ndstrides: introduce NDArrayObject & refactor reshape
This commit is contained in:
parent
7436513b64
commit
3dc4b17310
@ -1,3 +1,4 @@
|
|||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod object;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
68
nac3core/src/codegen/numpy_new/object.rs
Normal file
68
nac3core/src/codegen/numpy_new/object.rs
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{model::*, structure::ndarray::NpArray, CodeGenContext},
|
||||||
|
toplevel::numpy::unpack_ndarray_var_tys,
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An LLVM ndarray instance with its typechecker [`Type`]s.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct NDArrayObject<'ctx> {
|
||||||
|
pub dtype: Type,
|
||||||
|
pub ndims: Type,
|
||||||
|
pub instance: Ptr<'ctx, StructModel<NpArray>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An LLVM numpy scalar with its [`Type`].
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ScalarObject<'ctx> {
|
||||||
|
pub dtype: Type,
|
||||||
|
pub value: BasicValueEnum<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum ScalarOrNDArray<'ctx> {
|
||||||
|
Scalar(ScalarObject<'ctx>),
|
||||||
|
NDArray(NDArrayObject<'ctx>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||||
|
fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
|
||||||
|
fn from(input: ScalarOrNDArray<'ctx>) -> BasicValueEnum<'ctx> {
|
||||||
|
input.to_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Split an [`BasicValueEnum<'ctx>`] into a [`ScalarOrNDArray`] depending
|
||||||
|
/// on its [`Type`].
|
||||||
|
pub fn split_scalar_or_ndarray<'ctx>(
|
||||||
|
tyctx: TypeContext<'ctx>,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
input: BasicValueEnum<'ctx>,
|
||||||
|
input_ty: Type,
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
let pndarray_model = PtrModel(StructModel(NpArray));
|
||||||
|
|
||||||
|
let input_ty_enum = ctx.unifier.get_ty(input_ty);
|
||||||
|
match &*input_ty_enum {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let value = pndarray_model.check_value(tyctx, ctx.ctx, input).unwrap();
|
||||||
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, input_ty);
|
||||||
|
|
||||||
|
ScalarOrNDArray::NDArray(NDArrayObject { dtype, ndims, instance: value })
|
||||||
|
}
|
||||||
|
_ => ScalarOrNDArray::Scalar(ScalarObject { dtype: input_ty, value: input }),
|
||||||
|
}
|
||||||
|
}
|
@ -1,9 +1,11 @@
|
|||||||
use inkwell::{types::BasicType, values::BasicValueEnum};
|
use inkwell::types::BasicType;
|
||||||
|
use util::gen_model_memcpy;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
irrt::ndarray::basic::{
|
irrt::ndarray::basic::{
|
||||||
call_nac3_ndarray_get_nth_pelement, call_nac3_ndarray_nbytes,
|
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
|
||||||
|
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
|
||||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||||
call_nac3_ndarray_util_assert_shape_no_negative,
|
call_nac3_ndarray_util_assert_shape_no_negative,
|
||||||
},
|
},
|
||||||
@ -13,10 +15,31 @@ use crate::{
|
|||||||
util::{array_writer::ArrayWriter, control::gen_model_for},
|
util::{array_writer::ArrayWriter, control::gen_model_for},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
toplevel::numpy::unpack_ndarray_var_tys,
|
symbol_resolver::SymbolValue,
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
typecheck::typedef::{Type, TypeEnum, Unifier},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::object::{NDArrayObject, ScalarOrNDArray};
|
||||||
|
|
||||||
|
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||||
|
#[must_use]
|
||||||
|
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||||
|
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||||
|
panic!("ndims_ty should be a TLiteral");
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||||
|
|
||||||
|
let ndims = values[0].clone();
|
||||||
|
u64::try_from(ndims).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||||
|
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||||
|
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||||
|
}
|
||||||
|
|
||||||
/// Allocate an ndarray on the stack given its `ndims`.
|
/// Allocate an ndarray on the stack given its `ndims`.
|
||||||
///
|
///
|
||||||
/// `shape` and `strides` will be automatically allocated on the stack.
|
/// `shape` and `strides` will be automatically allocated on the stack.
|
||||||
@ -90,56 +113,6 @@ pub fn init_ndarray_data_by_alloca<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, pndarray);
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, pndarray);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert `input` to an ndarray - behaves similarly to `np.asarray`.
|
|
||||||
///
|
|
||||||
/// Returns the ndarray interpretation of `input` and **the element type** of the ndarray.
|
|
||||||
///
|
|
||||||
/// Here are the exact details:
|
|
||||||
/// - If `input` is an ndarray, the function returns back the **same** ndarray and the `dtype`
|
|
||||||
/// of the ndarray.
|
|
||||||
/// - If `input` is not an ndarray, the function creates an ndarray with a single element `input`,
|
|
||||||
/// and returns the created ndarray and `input_ty`. Note that the created ndarray's `ndims` will
|
|
||||||
/// be `0` (an *unsized* ndarray).
|
|
||||||
pub fn as_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
input: BasicValueEnum<'ctx>,
|
|
||||||
input_ty: Type,
|
|
||||||
) -> (Ptr<'ctx, StructModel<NpArray>>, Type) {
|
|
||||||
let tyctx = generator.type_context(ctx.ctx);
|
|
||||||
let sizet_model = IntModel(SizeT);
|
|
||||||
let pbyte_model = PtrModel(IntModel(Byte));
|
|
||||||
let pndarray_model = PtrModel(StructModel(NpArray));
|
|
||||||
|
|
||||||
let input_ty_enum = ctx.unifier.get_ty(input_ty);
|
|
||||||
match &*input_ty_enum {
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
let pndarray = pndarray_model.check_value(tyctx, ctx.ctx, input).unwrap();
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, input_ty);
|
|
||||||
(pndarray, elem_ty)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let ndims = sizet_model.const_0(tyctx, ctx.ctx);
|
|
||||||
let pndarray = alloca_ndarray(generator, ctx, ndims, "ndarray");
|
|
||||||
|
|
||||||
// We have to put `input` onto the stack to get a data pointer.
|
|
||||||
let data = ctx.builder.build_alloca(input.get_type(), "as_ndarray_scalar").unwrap();
|
|
||||||
ctx.builder.build_store(data, input).unwrap();
|
|
||||||
|
|
||||||
let data = pbyte_model.transmute(tyctx, ctx, data, "data");
|
|
||||||
pndarray.gep(ctx, |f| f.data).store(ctx, data);
|
|
||||||
|
|
||||||
let itemsize = input.get_type().size_of().unwrap();
|
|
||||||
let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap();
|
|
||||||
pndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
|
|
||||||
|
|
||||||
(pndarray, input_ty)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Iterate through all elements in an ndarray.
|
/// Iterate through all elements in an ndarray.
|
||||||
///
|
///
|
||||||
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
|
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
|
||||||
@ -180,3 +153,141 @@ where
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Convert `input` to an ndarray - behaves like `np.asarray`.
|
||||||
|
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let pbyte_model = PtrModel(IntModel(Byte));
|
||||||
|
|
||||||
|
// We have to put the value on the stack to get a data pointer.
|
||||||
|
let data =
|
||||||
|
ctx.builder.build_alloca(scalar.value.get_type(), "as_ndarray_scalar").unwrap();
|
||||||
|
ctx.builder.build_store(data, scalar.value).unwrap();
|
||||||
|
let data = pbyte_model.transmute(tyctx, ctx, data, "data");
|
||||||
|
|
||||||
|
let ndims_ty = create_ndims(&mut ctx.unifier, 0);
|
||||||
|
let ndarray = NDArrayObject::alloca(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndims_ty,
|
||||||
|
scalar.dtype,
|
||||||
|
"scalar_as_ndarray",
|
||||||
|
);
|
||||||
|
ndarray.instance.gep(ctx, |f| f.data).store(ctx, data);
|
||||||
|
|
||||||
|
// No need to initialize/setup strides or shapes - because `ndims` is 0.
|
||||||
|
// So we only have to set `data`, `itemsize`, and `ndims = 0`.
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
pub fn alloca<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndims: Type,
|
||||||
|
dtype: Type,
|
||||||
|
name: &str,
|
||||||
|
) -> Self {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
let ndims_int = sizet_model.constant(tyctx, ctx.ctx, extract_ndims(&ctx.unifier, ndims));
|
||||||
|
let instance = alloca_ndarray(generator, ctx, ndims_int, name);
|
||||||
|
|
||||||
|
// Set itemsize
|
||||||
|
let dtype_ty = ctx.get_llvm_type(generator, dtype);
|
||||||
|
let itemsize = dtype_ty.size_of().unwrap();
|
||||||
|
let itemsize = sizet_model.s_extend_or_bit_cast(tyctx, ctx, itemsize, "itemsize");
|
||||||
|
instance.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
|
||||||
|
|
||||||
|
NDArrayObject { dtype, ndims, instance }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn copy_shape<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
|
) {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
let self_shape = self.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "self_shape");
|
||||||
|
let ndims_int =
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, extract_ndims(&ctx.unifier, self.ndims));
|
||||||
|
gen_model_memcpy(tyctx, ctx, self_shape, src_shape, ndims_int.value, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn copy_shape_from<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayObject<'ctx>,
|
||||||
|
) {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let src_shape = src_ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "src_shape");
|
||||||
|
self.copy_shape(generator, ctx, src_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) {
|
||||||
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Int<'ctx, SizeT> {
|
||||||
|
call_nac3_ndarray_size(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn nbytes<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Int<'ctx, SizeT> {
|
||||||
|
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Int<'ctx, Bool> {
|
||||||
|
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloca_owned_data<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) {
|
||||||
|
init_ndarray_data_by_alloca(generator, ctx, self.instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src: NDArrayObject<'ctx>,
|
||||||
|
) {
|
||||||
|
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
|
||||||
|
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,90 +3,67 @@ use nac3parser::ast::StrRef;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
irrt::ndarray::{
|
irrt::ndarray::reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||||
basic::{
|
|
||||||
call_nac3_ndarray_copy_data, call_nac3_ndarray_is_c_contiguous,
|
|
||||||
call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape,
|
|
||||||
call_nac3_ndarray_size,
|
|
||||||
},
|
|
||||||
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
|
||||||
},
|
|
||||||
model::*,
|
model::*,
|
||||||
numpy_new::util::{alloca_ndarray, init_ndarray_shape},
|
numpy_new::{object::split_scalar_or_ndarray, util::extract_ndims},
|
||||||
structure::ndarray::NpArray,
|
util::shape::make_shape_writer,
|
||||||
util::{array_writer::ArrayWriter, shape::make_shape_writer},
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::DefinitionId,
|
toplevel::{numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||||
typecheck::typedef::{FunSignature, Type},
|
typecheck::typedef::{FunSignature, Type},
|
||||||
};
|
};
|
||||||
|
|
||||||
fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
use super::object::NDArrayObject;
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
new_ndims: Type,
|
||||||
new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>,
|
new_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String> {
|
) -> Self {
|
||||||
let tyctx = generator.type_context(ctx.ctx);
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
let byte_model = IntModel(Byte);
|
|
||||||
|
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
||||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||||
|
|
||||||
// Inserting into current_bb
|
let dst_ndarray =
|
||||||
let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray");
|
NDArrayObject::alloca(generator, ctx, new_ndims, self.dtype, "reshaped_ndarray");
|
||||||
|
dst_ndarray.copy_shape(generator, ctx, new_shape);
|
||||||
|
dst_ndarray.update_strides_by_shape(generator, ctx);
|
||||||
|
|
||||||
// Set shape - directly from user input
|
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
||||||
init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?;
|
|
||||||
dst_ndarray
|
|
||||||
.gep(ctx, |f| f.itemsize)
|
|
||||||
.store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize"));
|
|
||||||
|
|
||||||
// Resolve shape input from user
|
|
||||||
let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray);
|
|
||||||
call_nac3_ndarray_resolve_and_check_new_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
src_ndarray_size,
|
|
||||||
dst_ndarray.gep(ctx, |f| f.ndims).load(tyctx, ctx, "ndims"),
|
|
||||||
dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Update strides
|
|
||||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
|
|
||||||
|
|
||||||
let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray);
|
|
||||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||||
|
|
||||||
// Inserting into then_bb: reshape is possible without copying
|
// Inserting into then_bb: reshape is possible without copying
|
||||||
ctx.builder.position_at_end(then_bb);
|
ctx.builder.position_at_end(then_bb);
|
||||||
dst_ndarray
|
dst_ndarray
|
||||||
|
.instance
|
||||||
.gep(ctx, |f| f.data)
|
.gep(ctx, |f| f.data)
|
||||||
.store(ctx, src_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "data"));
|
.store(ctx, dst_ndarray.instance.gep(ctx, |f| f.data).load(tyctx, ctx, "data"));
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
// Inserting into else_bb: reshape is impossible without copying
|
// Inserting into else_bb: reshape is impossible without copying
|
||||||
ctx.builder.position_at_end(else_bb);
|
ctx.builder.position_at_end(else_bb);
|
||||||
// Allocate data
|
dst_ndarray.alloca_owned_data(generator, ctx);
|
||||||
let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray);
|
dst_ndarray.copy_data_from(generator, ctx, *self);
|
||||||
let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data");
|
|
||||||
dst_ndarray.gep(ctx, |f| f.data).store(ctx, data);
|
|
||||||
// Copy content
|
|
||||||
call_nac3_ndarray_copy_data(generator, ctx, src_ndarray, dst_ndarray);
|
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
// Reposition for continuation
|
// Reposition for continuation
|
||||||
ctx.builder.position_at_end(end_bb);
|
ctx.builder.position_at_end(end_bb);
|
||||||
|
|
||||||
Ok(dst_ndarray)
|
dst_ndarray
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `np.reshape`.
|
/// Generates LLVM IR for `np.reshape`.
|
||||||
pub fn gen_ndarray_reshape<'ctx>(
|
pub fn gen_ndarray_reshape<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
@ -95,21 +72,43 @@ pub fn gen_ndarray_reshape<'ctx>(
|
|||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 2);
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
// Parse argument #1 ndarray
|
// Parse argument #1 input
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
let input_ty = fun.0.args[0].ty;
|
||||||
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
|
let input_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||||
|
|
||||||
// Parse argument #2 shape
|
// Parse argument #2 shape
|
||||||
let shape_ty = fun.0.args[1].ty;
|
let shape_ty = fun.0.args[1].ty;
|
||||||
let shape_arg = args[1].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
|
||||||
let tyctx = generator.type_context(context.ctx);
|
// Define models
|
||||||
let pndarray_model = PtrModel(StructModel(NpArray));
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
let src_ndarray = pndarray_model.check_value(tyctx, context.ctx, ndarray_arg).unwrap();
|
// Extract reshaped_ndims
|
||||||
let new_shape = make_shape_writer(generator, context, shape_arg, shape_ty);
|
let (_, reshaped_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let reshaped_ndims_int = extract_ndims(&ctx.unifier, reshaped_ndims);
|
||||||
|
|
||||||
let reshaped_ndarray =
|
// Process `input`
|
||||||
gen_reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
let ndarray =
|
||||||
Ok(reshaped_ndarray.value)
|
split_scalar_or_ndarray(tyctx, ctx, input_arg, input_ty).as_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
// Process the shape input from user and resolve negative indices
|
||||||
|
let new_shape = make_shape_writer(generator, ctx, shape_arg, shape_ty).alloca_array_and_write(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"new_shape",
|
||||||
|
)?;
|
||||||
|
let size = ndarray.size(generator, ctx);
|
||||||
|
call_nac3_ndarray_resolve_and_check_new_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
size,
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, reshaped_ndims_int),
|
||||||
|
new_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Reshape
|
||||||
|
let reshaped_ndarray = ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
|
||||||
|
|
||||||
|
Ok(reshaped_ndarray.instance.value)
|
||||||
}
|
}
|
||||||
|
@ -15,3 +15,20 @@ pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Mode
|
|||||||
+ 'ctx,
|
+ 'ctx,
|
||||||
>,
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> ArrayWriter<'ctx, G, Len, Item> {
|
||||||
|
pub fn alloca_array_and_write(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Ptr<'ctx, Item>, String> {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
|
||||||
|
let item_model = Item::default();
|
||||||
|
|
||||||
|
let item_array = item_model.array_alloca(tyctx, ctx, self.len.value, name);
|
||||||
|
(self.write)(generator, ctx, item_array)?;
|
||||||
|
Ok(item_array)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user