WIP - core/ndstrides: implement np_{zeros,ones,full,empty}

This commit is contained in:
David Mak 2024-12-16 15:26:18 +08:00
parent 28aaafb38e
commit e70805eeaa
6 changed files with 363 additions and 68 deletions

View File

@ -4,7 +4,7 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use nac3core::codegen::values::ndarray::shape::parse_numpy_int_sequence;
use nac3parser::ast::{Operator, StrRef};
use super::{
@ -19,11 +19,12 @@ use super::{
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
types::{ndarray::NDArrayType, ListType, ProxyType},
types::{ndarray::{factory::{ndarray_one_value, ndarray_zero_value}, NDArrayType}, ListType, ProxyType},
values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
ndarray::NDArrayValue,
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
},
CodeGenContext, CodeGenerator,
};
@ -35,6 +36,7 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum},
},
};
use crate::toplevel::helper::extract_ndims;
/// Creates an `NDArray` instance from a dynamic shape.
///
@ -174,60 +176,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
/// Build the constant used as the "zero" element for an ndarray whose element type is `elem_ty`.
///
/// The checks run in this order and the first match wins:
/// * `int32` / `uint32` - `i32` zero
/// * `int64` / `uint64` - `i64` zero
/// * `float` - `f64` zero
/// * `bool` - zero of `bool_type()`
/// * `str` - the empty string produced by `gen_string`
///
/// Any other type falls through to `codegen_unreachable!`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
) -> BasicValueEnum<'ctx> {
    // 32-bit integers, signed or unsigned, share one LLVM representation.
    if ctx.unifier.unioned(elem_ty, ctx.primitives.int32)
        || ctx.unifier.unioned(elem_ty, ctx.primitives.uint32)
    {
        return ctx.ctx.i32_type().const_zero().into();
    }

    // Likewise for 64-bit integers.
    if ctx.unifier.unioned(elem_ty, ctx.primitives.int64)
        || ctx.unifier.unioned(elem_ty, ctx.primitives.uint64)
    {
        return ctx.ctx.i64_type().const_zero().into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
        return ctx.ctx.f64_type().const_zero().into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
        return ctx.ctx.bool_type().const_zero().into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
        return ctx.gen_string(generator, "").into();
    }

    codegen_unreachable!(ctx)
}
/// Build the constant used as the "one" element for an ndarray whose element type is `elem_ty`.
///
/// The checks run in this order and the first match wins:
/// * `int32` / `uint32` - `i32` constant 1
/// * `int64` / `uint64` - `i64` constant 1
/// * `float` - `f64` constant 1.0
/// * `bool` - constant 1 of `bool_type()`
/// * `str` - the string `"1"` produced by `gen_string`
///
/// Any other type falls through to `codegen_unreachable!`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
) -> BasicValueEnum<'ctx> {
    if ctx.unifier.unioned(elem_ty, ctx.primitives.int32)
        || ctx.unifier.unioned(elem_ty, ctx.primitives.uint32)
    {
        // The second argument of `const_int` is true only for the signed variant.
        let signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
        return ctx.ctx.i32_type().const_int(1, signed).into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.int64)
        || ctx.unifier.unioned(elem_ty, ctx.primitives.uint64)
    {
        let signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
        return ctx.ctx.i64_type().const_int(1, signed).into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
        return ctx.ctx.f64_type().const_float(1.0).into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
        return ctx.ctx.bool_type().const_int(1, false).into();
    }

    if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
        return ctx.gen_string(generator, "1").into();
    }

    codegen_unreachable!(ctx)
}
/// LLVM-typed implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -1723,8 +1671,19 @@ pub fn gen_ndarray_empty<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_dtype = context.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(
generator,
context,
(shape_ty, shape_arg),
);
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
.construct_numpy_empty(generator, context, shape);
Ok(ndarray.as_base_value())
}
/// Generates LLVM IR for `ndarray.zeros`.
@ -1741,8 +1700,19 @@ pub fn gen_ndarray_zeros<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_dtype = context.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(
generator,
context,
(shape_ty, shape_arg),
);
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
.construct_numpy_zeros(generator, context, dtype, shape);
Ok(ndarray.as_base_value())
}
/// Generates LLVM IR for `ndarray.ones`.
@ -1759,8 +1729,19 @@ pub fn gen_ndarray_ones<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_dtype = context.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(
generator,
context,
(shape_ty, shape_arg),
);
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
.construct_numpy_ones(generator, context, dtype, shape);
Ok(ndarray.as_base_value())
}
/// Generates LLVM IR for `ndarray.full`.
@ -1780,8 +1761,19 @@ pub fn gen_ndarray_full<'ctx>(
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_dtype = context.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = parse_numpy_int_sequence(
generator,
context,
(shape_ty, shape_arg),
);
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
.construct_numpy_full(generator, context, shape, fill_value_arg);
Ok(ndarray.as_base_value())
}
pub fn gen_ndarray_array<'ctx>(

View File

@ -0,0 +1,139 @@
use crate::codegen::{irrt, CodeGenContext, CodeGenerator};
use crate::typecheck::typedef::Type;
use inkwell::values::{BasicValueEnum, IntValue};
use nac3core::codegen::values::ndarray::NDArrayValue;
use crate::codegen::types::ndarray::NDArrayType;
use crate::codegen::types::ProxyType;
use crate::codegen::values::TypedArrayLikeAccessor;
/// Produce the constant that `np.zeros()` stores in every element for the given `dtype`.
///
/// The checks run in this order and the first match wins:
/// * `int32` / `uint32` - `i32` zero
/// * `int64` / `uint64` - `i64` zero
/// * `float` - `f64` zero
/// * `bool` - zero of `bool_type()`
/// * `str` - the empty string produced by `gen_string`
///
/// # Panics
///
/// Panics with `unrecognized dtype: ...` when `dtype` is none of the types above.
// TODO: Make this non-pub
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    dtype: Type,
) -> BasicValueEnum<'ctx> {
    // 32-bit integers, signed or unsigned, share one LLVM representation.
    if ctx.unifier.unioned(dtype, ctx.primitives.int32)
        || ctx.unifier.unioned(dtype, ctx.primitives.uint32)
    {
        return ctx.ctx.i32_type().const_zero().into();
    }

    // Likewise for 64-bit integers.
    if ctx.unifier.unioned(dtype, ctx.primitives.int64)
        || ctx.unifier.unioned(dtype, ctx.primitives.uint64)
    {
        return ctx.ctx.i64_type().const_zero().into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.float) {
        return ctx.ctx.f64_type().const_zero().into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
        return ctx.ctx.bool_type().const_zero().into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.str) {
        return ctx.gen_string(generator, "").into();
    }

    panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype))
}
/// Produce the constant that `np.ones()` stores in every element for the given `dtype`.
///
/// The checks run in this order and the first match wins:
/// * `int32` / `uint32` - `i32` constant 1
/// * `int64` / `uint64` - `i64` constant 1
/// * `float` - `f64` constant 1.0
/// * `bool` - constant 1 of `bool_type()`
/// * `str` - the string `"1"` produced by `gen_string`
///
/// # Panics
///
/// Panics with `unrecognized dtype: ...` when `dtype` is none of the types above.
// TODO: Make this non-pub
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    dtype: Type,
) -> BasicValueEnum<'ctx> {
    if ctx.unifier.unioned(dtype, ctx.primitives.int32)
        || ctx.unifier.unioned(dtype, ctx.primitives.uint32)
    {
        // The second argument of `const_int` is true only for the signed variant.
        let signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
        return ctx.ctx.i32_type().const_int(1, signed).into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.int64)
        || ctx.unifier.unioned(dtype, ctx.primitives.uint64)
    {
        let signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
        return ctx.ctx.i64_type().const_int(1, signed).into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.float) {
        return ctx.ctx.f64_type().const_float(1.0).into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
        return ctx.ctx.bool_type().const_int(1, false).into();
    }

    if ctx.unifier.unioned(dtype, ctx.primitives.str) {
        return ctx.gen_string(generator, "1").into();
    }

    panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype))
}
impl<'ctx> NDArrayType<'ctx> {
    /// Check that the LLVM type derived from `dtype` is this ndarray type's element type.
    ///
    /// # Panics
    ///
    /// Panics when the two LLVM types differ.
    fn assert_llvm_dtype_matches<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        dtype: Type,
    ) {
        // Compute the LLVM type once and reuse it for both the comparison and the message.
        let llvm_dtype = ctx.get_llvm_type(generator, dtype);
        assert_eq!(
            llvm_dtype,
            self.dtype,
            "Expected LLVM dtype={} but got {}",
            self.dtype.print_to_string(),
            llvm_dtype.print_to_string(),
        );
    }

    /// Create an ndarray like `np.empty`.
    ///
    /// * `shape` - The shape of the new ndarray, as an array of `SizeT` values. It is validated
    ///   at runtime by `call_nac3_ndarray_util_assert_shape_no_negative` before being copied
    ///   into the ndarray.
    ///
    /// The elements of the returned ndarray are not written by this function.
    ///
    /// # Panics
    ///
    /// Panics if the number of dimensions of this ndarray type is not known.
    pub fn construct_numpy_empty<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
    ) -> <Self as ProxyType<'ctx>>::Value {
        let ndarray = self.construct_uninitialized(generator, ctx, None);

        // Validate `shape`
        let ndims = self
            .ndims
            .expect("construct_numpy_empty requires ndims of NDArray to be known.");
        let ndims_llvm = self.llvm_usize.const_int(ndims, false);
        // Fetch the base pointer once; it is needed for both validation and the copy below.
        let shape_ptr = shape.base_ptr(ctx, generator);
        irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(
            generator, ctx, ndims_llvm, shape_ptr,
        );

        ndarray.copy_shape_from_array(generator, ctx, shape_ptr);
        // NOTE(review): `create_data` is unsafe; presumably it requires the shape to be set
        // first, which the copy above has just done - confirm against its contract.
        unsafe { ndarray.create_data(generator, ctx) };

        ndarray
    }

    /// Create an ndarray like `np.full`.
    ///
    /// * `shape` - The shape of the new ndarray; see [`Self::construct_numpy_empty`].
    /// * `fill_value` - The scalar stored into every element of the new ndarray.
    pub fn construct_numpy_full<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
        fill_value: BasicValueEnum<'ctx>,
    ) -> <Self as ProxyType<'ctx>>::Value {
        let ndarray = self.construct_numpy_empty(generator, ctx, shape);
        ndarray.fill(generator, ctx, fill_value);
        ndarray
    }

    /// Create an ndarray like `np.zeros`.
    ///
    /// * `dtype` - The typechecker element type; its LLVM type must equal the element type of
    ///   this ndarray type.
    /// * `shape` - The shape of the new ndarray; see [`Self::construct_numpy_empty`].
    ///
    /// # Panics
    ///
    /// Panics if the LLVM type of `dtype` differs from the element type of this ndarray type.
    pub fn construct_numpy_zeros<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        dtype: Type,
        shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
    ) -> <Self as ProxyType<'ctx>>::Value {
        self.assert_llvm_dtype_matches(generator, ctx, dtype);

        let fill_value = ndarray_zero_value(generator, ctx, dtype);
        self.construct_numpy_full(generator, ctx, shape, fill_value)
    }

    /// Create an ndarray like `np.ones`.
    ///
    /// * `dtype` - The typechecker element type; its LLVM type must equal the element type of
    ///   this ndarray type.
    /// * `shape` - The shape of the new ndarray; see [`Self::construct_numpy_empty`].
    ///
    /// # Panics
    ///
    /// Panics if the LLVM type of `dtype` differs from the element type of this ndarray type.
    pub fn construct_numpy_ones<G: CodeGenerator + ?Sized>(
        &self,
        generator: &mut G,
        ctx: &mut CodeGenContext<'ctx, '_>,
        dtype: Type,
        shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>,
    ) -> <Self as ProxyType<'ctx>>::Value {
        self.assert_llvm_dtype_matches(generator, ctx, dtype);

        let fill_value = ndarray_one_value(generator, ctx, dtype);
        self.construct_numpy_full(generator, ctx, shape, fill_value)
    }
}

View File

@ -25,6 +25,7 @@ pub use indexing::*;
pub use nditer::*;
mod contiguous;
pub mod factory;
mod indexing;
mod nditer;

View File

@ -135,8 +135,10 @@ impl<'ctx> NDIterType<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> <Self as ProxyType<'ctx>>::Value {
assert!(ndarray.get_type().ndims().is_some(), "NDIter requires ndims of NDArray to be known.");
let nditer = self.raw_alloca(generator, ctx, None);
let ndims = ndarray.load_ndims(ctx);
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false);
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices =

View File

@ -24,6 +24,7 @@ pub use view::*;
mod contiguous;
mod indexing;
mod nditer;
pub mod shape;
mod view;
/// Proxy type for accessing an `NDArray` value in LLVM.
@ -406,6 +407,23 @@ impl<'ctx> NDArrayValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
}
/// Store the scalar `value` into every element of this ndarray.
///
/// `value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
    &self,
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    value: BasicValueEnum<'ctx>,
) {
    // Visit each element through `foreach` and overwrite it with `value`.
    let outcome = self.foreach(generator, ctx, |_, ctx, _hooks, nditer| {
        let element_ptr = nditer.get_pointer(ctx);
        ctx.builder.build_store(element_ptr, value).unwrap();
        Ok(())
    });
    outcome.unwrap();
}
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
#[must_use]
pub fn is_unsized(&self) -> Option<bool> {

View File

@ -0,0 +1,143 @@
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::types::{ListType, TupleType};
use crate::codegen::values::{ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor};
use crate::codegen::{CodeGenContext, CodeGenerator};
use crate::typecheck::typedef::{Type, TypeEnum};
use inkwell::values::{BasicValue, BasicValueEnum, IntValue};
/// Convert a NumPy-like "int sequence" argument into an array of `SizeT` values.
///
/// * `input_seq_ty` - The typechecker type of the sequence argument.
/// * `input_seq` - The LLVM value of the sequence argument.
///
/// The argument may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// Every element is converted to `SizeT` with `build_int_s_extend_or_bit_cast` and written into
/// a freshly allocated array, which is returned as an array-like accessor.
///
/// # Panics
///
/// Panics with `encountered unknown sequence type: ...` when `input_seq_ty` is none of the
/// accepted forms.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> {
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let idx_zero = llvm_usize.const_zero();
    let idx_one = llvm_usize.const_int(1, false);

    match &*ctx.unifier.get_ty_immutable(input_seq_ty) {
        // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
        TypeEnum::TObj { obj_id, .. }
            if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
        {
            let list = ListType::from_unifier_type(generator, ctx, input_seq_ty)
                .map_value(input_seq.into_pointer_value(), None);
            let num_dims = list.load_size(ctx, None);

            // TODO: Find a way to remove this mid-BB allocation
            let buffer = ctx.builder.build_array_alloca(llvm_usize, num_dims, "").unwrap();
            let dims = TypedArrayLikeAdapter::from(
                ArraySliceValue::from_ptr_val(buffer, num_dims, None),
                Box::new(|_, val| val.into_int_value()),
                Box::new(|_, val| val.as_basic_value_enum()),
            );

            // Emit a loop that copies each list element, widened to `SizeT`, into `dims`.
            gen_for_callback_incrementing(
                generator,
                ctx,
                None,
                idx_zero,
                (num_dims, false),
                |generator, ctx, _, i| {
                    // NOTE(review): the unchecked accesses rely on `i` staying below
                    // `num_dims`; presumably `(num_dims, false)` is an exclusive bound - confirm.
                    let elem = unsafe {
                        list.data().get_unchecked(ctx, generator, &i, None).into_int_value()
                    };
                    let elem =
                        ctx.builder.build_int_s_extend_or_bit_cast(elem, llvm_usize, "").unwrap();
                    unsafe { dims.set_typed_unchecked(ctx, generator, &i, elem) };

                    Ok(())
                },
                idx_one,
            )
            .unwrap();

            dims
        }

        // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
        TypeEnum::TTuple { .. } => {
            let tuple = TupleType::from_unifier_type(generator, ctx, input_seq_ty)
                .map_value(input_seq.into_struct_value(), None);
            // The element count of a tuple is known at compile time.
            let num_dims = tuple.get_type().num_elements();

            let buffer = generator
                .gen_array_var_alloc(
                    ctx,
                    llvm_usize.into(),
                    llvm_usize.const_int(u64::from(num_dims), false),
                    None,
                )
                .unwrap();
            let dims = TypedArrayLikeAdapter::from(
                buffer,
                Box::new(|_, val| val.into_int_value()),
                Box::new(|_, val| val.as_basic_value_enum()),
            );

            for i in 0..num_dims {
                // Read the i-th tuple element, widen it to `SizeT`, and store it at index `i`.
                let elem = tuple.load_element(ctx, i).into_int_value();
                let elem =
                    ctx.builder.build_int_s_extend_or_bit_cast(elem, llvm_usize, "").unwrap();
                let index = llvm_usize.const_int(u64::from(i), false);

                // `i` is below `num_dims`, the length `buffer` was allocated with.
                unsafe {
                    dims.set_typed_unchecked(ctx, generator, &index, elem);
                }
            }

            dims
        }

        // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
        TypeEnum::TObj { obj_id, .. }
            if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
        {
            let scalar = input_seq.into_int_value();

            // A scalar becomes a one-element array.
            let buffer =
                generator.gen_array_var_alloc(ctx, llvm_usize.into(), idx_one, None).unwrap();
            let dims = TypedArrayLikeAdapter::from(
                buffer,
                Box::new(|_, val| val.into_int_value()),
                Box::new(|_, val| val.as_basic_value_enum()),
            );
            let elem =
                ctx.builder.build_int_s_extend_or_bit_cast(scalar, llvm_usize, "").unwrap();

            // Index 0 is the only slot of the one-element buffer allocated above.
            unsafe {
                dims.set_typed_unchecked(ctx, generator, &idx_zero, elem);
            }

            dims
        }

        _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)),
    }
}