forked from M-Labs/nac3
WIP - core/ndstrides: implement np_{zeros,ones,full,empty}
This commit is contained in:
parent
28aaafb38e
commit
e70805eeaa
@ -4,7 +4,7 @@ use inkwell::{
|
|||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use nac3core::codegen::values::ndarray::shape::parse_numpy_int_sequence;
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
@ -19,11 +19,12 @@ use super::{
|
|||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
types::{ndarray::{factory::{ndarray_one_value, ndarray_zero_value}, NDArrayType}, ListType, ProxyType},
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
|
ndarray::NDArrayValue,
|
||||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, TypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||||
|
UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
@ -35,6 +36,7 @@ use crate::{
|
|||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use crate::toplevel::helper::extract_ndims;
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a dynamic shape.
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
///
|
///
|
||||||
@ -174,60 +176,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i32_type().const_zero().into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "").into()
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
|
||||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
|
||||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_float(1.0).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_int(1, false).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "1").into()
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
@ -1723,8 +1671,19 @@ pub fn gen_ndarray_empty<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
(shape_ty, shape_arg),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_empty(generator, context, shape);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.zeros`.
|
/// Generates LLVM IR for `ndarray.zeros`.
|
||||||
@ -1741,8 +1700,19 @@ pub fn gen_ndarray_zeros<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
(shape_ty, shape_arg),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_zeros(generator, context, dtype, shape);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.ones`.
|
/// Generates LLVM IR for `ndarray.ones`.
|
||||||
@ -1759,8 +1729,19 @@ pub fn gen_ndarray_ones<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
(shape_ty, shape_arg),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_ones(generator, context, dtype, shape);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.full`.
|
/// Generates LLVM IR for `ndarray.full`.
|
||||||
@ -1780,8 +1761,19 @@ pub fn gen_ndarray_full<'ctx>(
|
|||||||
let fill_value_arg =
|
let fill_value_arg =
|
||||||
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
(shape_ty, shape_arg),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_full(generator, context, shape, fill_value_arg);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gen_ndarray_array<'ctx>(
|
pub fn gen_ndarray_array<'ctx>(
|
||||||
|
139
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
139
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
use crate::codegen::{irrt, CodeGenContext, CodeGenerator};
|
||||||
|
use crate::typecheck::typedef::Type;
|
||||||
|
use inkwell::values::{BasicValueEnum, IntValue};
|
||||||
|
use nac3core::codegen::values::ndarray::NDArrayValue;
|
||||||
|
use crate::codegen::types::ndarray::NDArrayType;
|
||||||
|
use crate::codegen::types::ProxyType;
|
||||||
|
use crate::codegen::values::TypedArrayLikeAccessor;
|
||||||
|
|
||||||
|
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||||
|
// TODO: Make this non-pub
|
||||||
|
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
ctx.ctx.i32_type().const_zero().into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
ctx.ctx.i64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "").into()
|
||||||
|
} else {
|
||||||
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the one value in `np.ones()` of a `dtype`.
|
||||||
|
// TODO: Make this non-pub
|
||||||
|
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||||
|
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||||
|
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_float(1.0).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_int(1, false).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "1").into()
|
||||||
|
} else {
|
||||||
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
/// Create an ndarray like `np.empty`.
|
||||||
|
pub fn construct_numpy_empty<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let ndarray = self.construct_uninitialized(generator, ctx, None);
|
||||||
|
|
||||||
|
// Validate `shape`
|
||||||
|
let ndims_llvm = self.llvm_usize.const_int(self.ndims.unwrap(), false);
|
||||||
|
irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape.base_ptr(ctx, generator));
|
||||||
|
|
||||||
|
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||||
|
unsafe { ndarray.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.full`.
|
||||||
|
pub fn construct_numpy_full<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let ndarray = self.construct_numpy_empty(generator, ctx, shape);
|
||||||
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.zero`.
|
||||||
|
pub fn construct_numpy_zeros<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert_eq!(
|
||||||
|
ctx.get_llvm_type(generator, dtype),
|
||||||
|
self.dtype,
|
||||||
|
"Expected LLVM dtype={} but got {}",
|
||||||
|
self.dtype.print_to_string(),
|
||||||
|
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||||
|
self.construct_numpy_full(generator, ctx, shape, fill_value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.ones`.
|
||||||
|
pub fn construct_numpy_ones<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert_eq!(
|
||||||
|
ctx.get_llvm_type(generator, dtype),
|
||||||
|
self.dtype,
|
||||||
|
"Expected LLVM dtype={} but got {}",
|
||||||
|
self.dtype.print_to_string(),
|
||||||
|
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||||
|
self.construct_numpy_full(generator, ctx, shape, fill_value)
|
||||||
|
}
|
||||||
|
}
|
@ -25,6 +25,7 @@ pub use indexing::*;
|
|||||||
pub use nditer::*;
|
pub use nditer::*;
|
||||||
|
|
||||||
mod contiguous;
|
mod contiguous;
|
||||||
|
pub mod factory;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod nditer;
|
mod nditer;
|
||||||
|
|
||||||
|
@ -135,8 +135,10 @@ impl<'ctx> NDIterType<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert!(ndarray.get_type().ndims().is_some(), "NDIter requires ndims of NDArray to be known.");
|
||||||
|
|
||||||
let nditer = self.raw_alloca(generator, ctx, None);
|
let nditer = self.raw_alloca(generator, ctx, None);
|
||||||
let ndims = ndarray.load_ndims(ctx);
|
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false);
|
||||||
|
|
||||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||||
let indices =
|
let indices =
|
||||||
|
@ -24,6 +24,7 @@ pub use view::*;
|
|||||||
mod contiguous;
|
mod contiguous;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod nditer;
|
mod nditer;
|
||||||
|
pub mod shape;
|
||||||
mod view;
|
mod view;
|
||||||
|
|
||||||
/// Proxy type for accessing an `NDArray` value in LLVM.
|
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||||
@ -406,6 +407,23 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fill the ndarray with a scalar.
|
||||||
|
///
|
||||||
|
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
||||||
|
pub fn fill<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) {
|
||||||
|
self.foreach(generator, ctx, |_, ctx, _hooks, nditer| {
|
||||||
|
let p = nditer.get_pointer(ctx);
|
||||||
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn is_unsized(&self) -> Option<bool> {
|
pub fn is_unsized(&self) -> Option<bool> {
|
||||||
|
143
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
143
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
|
use crate::codegen::types::{ListType, TupleType};
|
||||||
|
use crate::codegen::values::{ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor};
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
use crate::typecheck::typedef::{Type, TypeEnum};
|
||||||
|
use inkwell::values::{BasicValue, BasicValueEnum, IntValue};
|
||||||
|
|
||||||
|
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||||
|
///
|
||||||
|
/// * `sequence` - The `sequence` parameter.
|
||||||
|
/// * `sequence_ty` - The typechecker type of `sequence`
|
||||||
|
///
|
||||||
|
/// The `sequence` argument type may only be one of the following:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// All `int32` values will be sign-extended to `SizeT`.
|
||||||
|
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let zero = llvm_usize.const_zero();
|
||||||
|
let one = llvm_usize.const_int(1, false);
|
||||||
|
|
||||||
|
// The result `list` to return.
|
||||||
|
match &*ctx.unifier.get_ty_immutable(input_seq_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||||
|
.map_value(input_seq.into_pointer_value(), None);
|
||||||
|
|
||||||
|
let len = input_seq.load_size(ctx, None);
|
||||||
|
// TODO: Find a way to remove this mid-BB allocation
|
||||||
|
let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(result, len, None),
|
||||||
|
Box::new(|_, val| val.into_int_value()),
|
||||||
|
Box::new(|_, val| val.as_basic_value_enum()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
zero,
|
||||||
|
(len, false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
// Load the i-th int32 in the input sequence
|
||||||
|
let int = unsafe {
|
||||||
|
input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Cast to SizeT
|
||||||
|
let int =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
// Store
|
||||||
|
unsafe { result.set_typed_unchecked(ctx, generator, &i, int) };
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
one,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeEnum::TTuple { .. } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||||
|
.map_value(input_seq.into_struct_value(), None);
|
||||||
|
|
||||||
|
let len = input_seq.get_type().num_elements();
|
||||||
|
|
||||||
|
let result = generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_usize.const_int(u64::from(len), false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
result,
|
||||||
|
Box::new(|_, val| val.into_int_value()),
|
||||||
|
Box::new(|_, val| val.as_basic_value_enum()),
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..input_seq.get_type().num_elements() {
|
||||||
|
// Get the i-th element off of the tuple and load it into `result`.
|
||||||
|
let int = input_seq.load_element(ctx, i).into_int_value();
|
||||||
|
let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
result.set_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(u64::from(i), false),
|
||||||
|
int,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
let input_int = input_seq.into_int_value();
|
||||||
|
|
||||||
|
let len = one;
|
||||||
|
let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
result,
|
||||||
|
Box::new(|_, val| val.into_int_value()),
|
||||||
|
Box::new(|_, val| val.as_basic_value_enum()),
|
||||||
|
);
|
||||||
|
let int =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
// Storing into result[0]
|
||||||
|
unsafe {
|
||||||
|
result.set_typed_unchecked(ctx, generator, &zero, int);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)),
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user