From e70805eeaaf4b3267f9e9f31effa420627aa255c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 15:26:18 +0800 Subject: [PATCH] WIP - core/ndstrides: implement np_{zeros,ones,full,empty} --- nac3core/src/codegen/numpy.rs | 126 ++++++++------- nac3core/src/codegen/types/ndarray/factory.rs | 139 +++++++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/types/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/ndarray/mod.rs | 18 +++ nac3core/src/codegen/values/ndarray/shape.rs | 143 ++++++++++++++++++ 6 files changed, 363 insertions(+), 68 deletions(-) create mode 100644 nac3core/src/codegen/types/ndarray/factory.rs create mode 100644 nac3core/src/codegen/values/ndarray/shape.rs diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ea025fcb..57c85da6 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -4,7 +4,7 @@ use inkwell::{ AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; - +use nac3core::codegen::values::ndarray::shape::parse_numpy_int_sequence; use nac3parser::ast::{Operator, StrRef}; use super::{ @@ -19,11 +19,12 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ndarray::NDArrayType, ListType, ProxyType}, + types::{ndarray::{factory::{ndarray_one_value, ndarray_zero_value}, NDArrayType}, ListType, ProxyType}, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ndarray::NDArrayValue, + ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -35,6 +36,7 @@ use crate::{ typedef::{FunSignature, Type, TypeEnum}, }, }; +use crate::toplevel::helper::extract_ndims; /// Creates an `NDArray` instance from a dynamic shape. /// @@ -174,60 +176,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - /// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1723,8 +1671,19 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence( + generator, + context, + (shape_ty, shape_arg), + ); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_empty(generator, context, shape); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1741,8 +1700,19 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence( + generator, + context, + (shape_ty, shape_arg), + ); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_zeros(generator, context, dtype, shape); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.ones`. @@ -1759,8 +1729,19 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence( + generator, + context, + (shape_ty, shape_arg), + ); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_ones(generator, context, dtype, shape); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.full`. @@ -1780,8 +1761,19 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence( + generator, + context, + (shape_ty, shape_arg), + ); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_full(generator, context, shape, fill_value_arg); + Ok(ndarray.as_base_value()) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs new file mode 100644 index 00000000..d64f51df --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -0,0 +1,139 @@ +use crate::codegen::{irrt, CodeGenContext, CodeGenerator}; +use crate::typecheck::typedef::Type; +use inkwell::values::{BasicValueEnum, IntValue}; +use nac3core::codegen::values::ndarray::NDArrayValue; +use crate::codegen::types::ndarray::NDArrayType; +use crate::codegen::types::ProxyType; +use crate::codegen::values::TypedArrayLikeAccessor; + +/// Get the zero value in `np.zeros()` of a `dtype`. +// TODO: Make this non-pub +pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +// TODO: Make this non-pub +pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayType<'ctx> { + /// Create an ndarray like `np.empty`. + pub fn construct_numpy_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, None); + + // Validate `shape` + let ndims_llvm = self.llvm_usize.const_int(self.ndims.unwrap(), false); + irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape.base_ptr(ctx, generator)); + + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + ndarray + } + + /// Create an ndarray like `np.full`. + pub fn construct_numpy_full( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>, + fill_value: BasicValueEnum<'ctx>, + ) -> >::Value { + let ndarray = self.construct_numpy_empty(generator, ctx, shape); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like `np.zero`. + pub fn construct_numpy_zeros( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_zero_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value) + } + + /// Create an ndarray like `np.ones`. + pub fn construct_numpy_ones( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_one_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value) + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index b655f352..927e38bd 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -25,6 +25,7 @@ pub use indexing::*; pub use nditer::*; mod contiguous; +pub mod factory; mod indexing; mod nditer; diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 33ad9250..aef10391 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -135,8 +135,10 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { + assert!(ndarray.get_type().ndims().is_some(), "NDIter requires ndims of NDArray to be known."); + let nditer = self.raw_alloca(generator, ctx, None); - let ndims = ndarray.load_ndims(ctx); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 12fd8634..d52114d7 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -24,6 +24,7 @@ pub use view::*; mod contiguous; mod indexing; mod nditer; +pub mod shape; mod view; /// Proxy type for accessing an `NDArray` value in LLVM. @@ -406,6 +407,23 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); } + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |_, ctx, _hooks, nditer| { + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs new file mode 100644 index 00000000..b223fe66 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -0,0 +1,143 @@ +use crate::codegen::stmt::gen_for_callback_incrementing; +use crate::codegen::types::{ListType, TupleType}; +use crate::codegen::values::{ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor}; +use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::typecheck::typedef::{Type, TypeEnum}; +use inkwell::values::{BasicValue, BasicValueEnum, IntValue}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), +) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let zero = llvm_usize.const_zero(); + let one = llvm_usize.const_int(1, false); + + // The result `list` to return. + match &*ctx.unifier.get_ty_immutable(input_seq_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_pointer_value(), None); + + let len = input_seq.load_size(ctx, None); + // TODO: Find a way to remove this mid-BB allocation + let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap(); + let result = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(result, len, None), + Box::new(|_, val| val.into_int_value()), + Box::new(|_, val| val.as_basic_value_enum()), + ); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_callback_incrementing( + generator, + ctx, + None, + zero, + (len, false), + |generator, ctx, _, i| { + // Load the i-th int32 in the input sequence + let int = unsafe { + input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value() + }; + + // Cast to SizeT + let int = + ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + // Store + unsafe { result.set_typed_unchecked(ctx, generator, &i, int) }; + + Ok(()) + }, + one, + ) + .unwrap(); + + result + } + + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_struct_value(), None); + + let len = input_seq.get_type().num_elements(); + + let result = generator + .gen_array_var_alloc( + ctx, + llvm_usize.into(), + llvm_usize.const_int(u64::from(len), false), + None, + ) + .unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + Box::new(|_, val| val.into_int_value()), + Box::new(|_, val| val.as_basic_value_enum()), + ); + + for i in 0..input_seq.get_type().num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_seq.load_element(ctx, i).into_int_value(); + let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + unsafe { + result.set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(u64::from(i), false), + int, + ); + } + } + + result + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + let input_int = input_seq.into_int_value(); + + let len = one; + let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + Box::new(|_, val| val.into_int_value()), + Box::new(|_, val| val.as_basic_value_enum()), + ); + let int = + ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap(); + + // Storing into result[0] + unsafe { + result.set_typed_unchecked(ctx, generator, &zero, int); + } + + result + } + + _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)), + } +}