From 3a241acc9cf1c478912fcbd7702ab72147546f83 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 12:25:01 +0800 Subject: [PATCH] core/ndstrides: implement np_{zeros,ones,full,empty} --- nac3core/src/codegen/numpy.rs | 43 ++++-- .../src/codegen/object/ndarray/factory.rs | 126 ++++++++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 26 +++- .../src/codegen/object/ndarray/shape_util.rs | 105 +++++++++++++++ 4 files changed, 290 insertions(+), 10 deletions(-) create mode 100644 nac3core/src/codegen/object/ndarray/factory.rs create mode 100644 nac3core/src/codegen/object/ndarray/shape_util.rs diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d58b566b..5cb4b4be 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -12,12 +12,16 @@ use crate::{ call_ndarray_calc_size, }, llvm_intrinsics::{self, call_memcpy_generic}, + object::{ + any::AnyObject, + ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, + }, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, @@ -1742,8 +1746,13 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_empty(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1760,8 +1769,13 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_zeros(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.ones`. @@ -1778,8 +1792,13 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_ones(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.full`. @@ -1799,8 +1818,14 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = + NDArrayObject::make_np_full(generator, context, dtype, ndims, shape, fill_value_arg); + Ok(ndarray.instance.value) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs new file mode 100644 index 00000000..a1c0af78 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -0,0 +1,126 @@ +use inkwell::values::BasicValueEnum; + +use crate::{ + codegen::{ + irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext, + CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::NDArrayObject; + +/// Get the zero value in `np.zeros()` of a `dtype`. +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create an ndarray like `np.empty`. + pub fn make_np_empty( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + // Validate `shape` + let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims); + call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.create_data(generator, ctx); + + ndarray + } + + /// Create an ndarray like `np.full`. + pub fn make_np_full( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + fill_value: BasicValueEnum<'ctx>, + ) -> Self { + let ndarray = NDArrayObject::make_np_empty(generator, ctx, dtype, ndims, shape); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like `np.zero`. + pub fn make_np_zeros( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_zero_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } + + /// Create an ndarray like `np.ones`. + pub fn make_np_ones( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_one_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index c1b5f67d..bb289b31 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,6 +1,13 @@ +pub mod factory; pub mod nditer; +pub mod shape_util; -use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace}; +use inkwell::{ + context::Context, + types::BasicType, + values::{BasicValueEnum, PointerValue}, + AddressSpace, +}; use crate::{ codegen::{ @@ -345,4 +352,21 @@ impl<'ctx> NDArrayObject<'ctx> { assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); } + + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } } diff --git a/nac3core/src/codegen/object/ndarray/shape_util.rs b/nac3core/src/codegen/object/ndarray/shape_util.rs new file mode 100644 index 00000000..aa6c3f80 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/shape_util.rs @@ -0,0 +1,105 @@ +use util::gen_for_model; + +use crate::{ + codegen::{ + model::*, + object::{any::AnyObject, list::ListObject, tuple::TupleObject}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::TypeEnum, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + input_sequence: AnyObject<'ctx>, +) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let one = Int(SizeT).const_1(generator, ctx.ctx); + + // The result `list` to return. + match &*ctx.unifier.get_ty(input_sequence.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + // Check `input_sequence` + let input_sequence = ListObject::from_object(generator, ctx, input_sequence); + + let len = input_sequence.instance.get(generator, ctx, |f| f.len); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_model(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| { + // Load the i-th int32 in the input sequence + let int = input_sequence + .instance + .get(generator, ctx, |f| f.items) + .get_index(generator, ctx, i.value) + .value + .into_int_value(); + + // Cast to SizeT + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + // Store + result.set_index(ctx, i.value, int); + + Ok(()) + }) + .unwrap(); + + (len, result) + } + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_sequence = TupleObject::from_object(ctx, input_sequence); + + let len = input_sequence.len(generator, ctx); + + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + for i in 0..input_sequence.num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_sequence.index(ctx, i).value.into_int_value(); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + result.set_index_const(ctx, i as u64, int); + } + + (len, result) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + let input_int = input_sequence.value.into_int_value(); + + let len = Int(SizeT).const_1(generator, ctx.ctx); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, input_int); + + // Storing into result[0] + result.store(ctx, int); + + (len, result) + } + _ => panic!( + "encountered unknown sequence type: {}", + ctx.unifier.stringify(input_sequence.ty) + ), + } +}