From 6153f94b0525ad616e5df923f2537ecf9f9e0340 Mon Sep 17 00:00:00 2001
From: David Mak
Date: Tue, 11 Jun 2024 15:29:32 +0800
Subject: [PATCH] core/numpy: Implement codegen for np_array

---
 nac3core/src/codegen/numpy.rs     | 517 +++++++++++++++++++++++++++++-
 nac3core/src/toplevel/builtins.rs |   7 +-
 2 files changed, 519 insertions(+), 5 deletions(-)

diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs
index 9fff4259..1e685a77 100644
--- a/nac3core/src/codegen/numpy.rs
+++ b/nac3core/src/codegen/numpy.rs
@@ -1,12 +1,16 @@
-use inkwell::{IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
+use inkwell::{AddressSpace, IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
+use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
 use nac3parser::ast::{Operator, StrRef};
 use crate::{
     codegen::{
         classes::{
             ArrayLikeIndexer,
             ArrayLikeValue,
+            ListType,
             ListValue,
+            NDArrayType,
             NDArrayValue,
+            ProxyType,
             ProxyValue,
             TypedArrayLikeAccessor,
             TypedArrayLikeAdapter,
@@ -31,9 +35,10 @@ use crate::{
     symbol_resolver::ValueEnum,
     toplevel::{
         DefinitionId,
+        helper::PRIMITIVE_DEF_IDS,
         numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
     },
-    typecheck::typedef::{FunSignature, Type},
+    typecheck::typedef::{FunSignature, Type, TypeEnum},
 };
 
 /// Creates an uninitialized `NDArray` instance.
@@ -589,6 +594,447 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
     Ok(ndarray)
 }
 
+/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
+///
+/// The depth is derived from the LLVM type alone: every nested list level adds one to a constant
+/// starting at 1.
+///
+/// * `ty` - The LLVM pointer type of the (possibly nested) list.
+///
+/// # Panics
+///
+/// Hits `todo!` if the innermost list's element type is an `NDArray` (`list[ndarray]`).
+fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &G,
+    ctx: &CodeGenContext<'ctx, '_>,
+    ty: PointerType<'ctx>,
+) -> IntValue<'ctx> {
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+
+    let list_ty = ListType::from_type(ty, llvm_usize);
+    let list_elem_ty = list_ty.element_type();
+
+    let ndims = llvm_usize.const_int(1, false);
+    match list_elem_ty {
+        AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
+            ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty))
+        }
+
+        AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
+            todo!("Getting ndims for list[ndarray] not supported")
+        }
+
+        _ => ndims,
+    }
+}
+
+/// Returns the number of dimensions for an array-like object as an [`IntValue`].
+///
+/// An `NDArray` has its `ndims` loaded at runtime, a list is resolved through
+/// [`llvm_ndlist_get_ndims`], and any other value is reported as 0-dimensional.
+fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &mut G,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    value: BasicValueEnum<'ctx>,
+) -> IntValue<'ctx> {
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+
+    match value {
+        BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => {
+            NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx)
+        }
+
+        BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => {
+            llvm_ndlist_get_ndims(generator, ctx, v.get_type())
+        }
+
+        _ => llvm_usize.const_zero(),
+    }
+}
+
+/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`].
+///
+/// * `elem_ty` - The type of the innermost (non-list) elements.
+/// * `dst_arr` - The destination `NDArray`, whose `dim_sizes` are used to compute strides.
+/// * `dst_slice_ptr` - The location within `dst_arr`'s data to which `src_lst` is written.
+/// * `src_lst` - The (possibly nested) list to copy from.
+/// * `dim` - The dimension of `dst_arr` which `src_lst` corresponds to.
+///
+/// # Panics
+///
+/// Hits `todo!` if the list's element type is an `NDArray` (`list[ndarray]`).
+fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &mut G,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    elem_ty: Type,
+    (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
+    src_lst: ListValue<'ctx>,
+    dim: u64,
+) -> Result<(), String> {
+    let llvm_i1 = ctx.ctx.bool_type();
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+
+    let list_elem_ty = src_lst.get_type().element_type();
+
+    match list_elem_ty {
+        AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
+            // The stride of elements in this dimension, i.e. the number of elements between arr[i]
+            // and arr[i + 1] in this dimension
+            let stride = call_ndarray_calc_size(
+                generator,
+                ctx,
+                &dst_arr.dim_sizes(),
+                (Some(llvm_usize.const_int(dim + 1, false)), None),
+            );
+
+            gen_for_range_callback(
+                generator,
+                ctx,
+                true,
+                |_, _| Ok(llvm_usize.const_zero()),
+                (|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
+                |_, _| Ok(llvm_usize.const_int(1, false)),
+                |generator, ctx, i| {
+                    let offset = ctx.builder.build_int_mul(
+                        stride,
+                        i,
+                        "",
+                    ).unwrap();
+
+                    let dst_ptr = unsafe {
+                        ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap()
+                    };
+
+                    let nested_lst_elem = ListValue::from_ptr_val(
+                        unsafe {
+                            src_lst.data().get_unchecked(ctx, generator, &i, None)
+                        }.into_pointer_value(),
+                        llvm_usize,
+                        None,
+                    );
+
+                    ndarray_from_ndlist_impl(
+                        generator,
+                        ctx,
+                        elem_ty,
+                        (dst_arr, dst_ptr),
+                        nested_lst_elem,
+                        dim + 1,
+                    )?;
+
+                    Ok(())
+                },
+            )?;
+        }
+
+        AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
+            todo!("Not implemented for list[ndarray]")
+        }
+
+        _ => {
+            // Innermost dimension: the list data is copied with a single memcpy.
+            // NOTE(review): `size_of()` may yield an integer wider than `llvm_usize` on targets
+            // with a 32-bit size type - confirm the operand types of this multiplication match.
+            let lst_len = src_lst.load_size(ctx, None);
+            let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
+            let cpy_len = ctx.builder.build_int_mul(
+                ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(),
+                sizeof_elem,
+                ""
+            ).unwrap();
+
+            call_memcpy_generic(
+                ctx,
+                dst_slice_ptr,
+                src_lst.data().base_ptr(ctx, generator),
+                cpy_len,
+                llvm_i1.const_zero(),
+            );
+        }
+    }
+
+    Ok(())
+}
+
+/// LLVM-typed implementation for `ndarray.array`.
+///
+/// * `elem_ty` - The element type of the resulting `NDArray`.
+/// * `object` - The source value: a non-pointer scalar, an `NDArray`, or a (nested) list.
+/// * `copy` - Non-zero if an `NDArray` source must always be copied.
+/// * `ndmin` - The minimum number of dimensions of the result; 1's are prepended to the shape
+///   as needed.
+///
+/// # Panics
+///
+/// Panics if `object` is a pointer value that is neither an `NDArray` nor a list.
+fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &mut G,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    elem_ty: Type,
+    object: BasicValueEnum<'ctx>,
+    copy: IntValue<'ctx>,
+    ndmin: IntValue<'ctx>,
+) -> Result<NDArrayValue<'ctx>, String> {
+    let llvm_i1 = ctx.ctx.bool_type();
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+
+    let ndmin = ctx.builder
+        .build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "")
+        .unwrap();
+
+    // TODO(Derppening): Add assertions for sizes of different dimensions
+
+    // object is not a pointer - 0-dim NDArray
+    if !object.is_pointer_value() {
+        let ndarray = create_ndarray_const_shape(
+            generator,
+            ctx,
+            elem_ty,
+            &[],
+        )?;
+
+        unsafe {
+            ndarray.data()
+                .set_unchecked(ctx, generator, &llvm_usize.const_zero(), object);
+        }
+
+        return Ok(ndarray)
+    }
+
+    let object = object.into_pointer_value();
+
+    // object is an NDArray instance - reuse it only if copy=0 && ndmin <= object.ndims, otherwise
+    // allocate a new NDArray (with 1's prepended to the shape up to ndmin) and copy into it
+    if NDArrayValue::is_instance(object, llvm_usize).is_ok() {
+        let object = NDArrayValue::from_ptr_val(object, llvm_usize, None);
+
+        let ndarray = gen_if_else_expr_callback(
+            generator,
+            ctx,
+            |_, ctx| {
+                let copy_nez = ctx.builder
+                    .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "")
+                    .unwrap();
+                let ndmin_gt_ndims = ctx.builder
+                    .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
+                    .unwrap();
+
+                // Copy when explicitly requested OR when dimensions have to be prepended
+                Ok(ctx.builder
+                    .build_or(copy_nez, ndmin_gt_ndims, "")
+                    .unwrap())
+            },
+            |generator, ctx| {
+                let ndarray = create_ndarray_dyn_shape(
+                    generator,
+                    ctx,
+                    elem_ty,
+                    &object,
+                    |_, ctx, object| {
+                        let ndims = object.load_ndims(ctx);
+                        let ndmin_gt_ndims = ctx.builder
+                            .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
+                            .unwrap();
+
+                        Ok(ctx.builder
+                            .build_select(ndmin_gt_ndims, ndmin, ndims, "")
+                            .map(BasicValueEnum::into_int_value)
+                            .unwrap())
+                    },
+                    |generator, ctx, object, idx| {
+                        let ndims = object.load_ndims(ctx);
+                        let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
+                        // The number of dimensions to prepend 1's to
+                        let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
+
+                        Ok(gen_if_else_expr_callback(
+                            generator,
+                            ctx,
+                            |_, ctx| {
+                                Ok(ctx.builder
+                                    .build_int_compare(IntPredicate::ULT, idx, offset, "")
+                                    .unwrap())
+                            },
+                            |_, _| {
+                                Ok(Some(llvm_usize.const_int(1, false)))
+                            },
+                            |generator, ctx| {
+                                // Remaining dimensions mirror the source, shifted by `offset`.
+                                // Assumes `idx` < max(ndmin, ndims), so `src_idx` < `ndims`.
+                                let src_idx = ctx.builder
+                                    .build_int_sub(idx, offset, "")
+                                    .unwrap();
+
+                                Ok(Some(unsafe {
+                                    object.dim_sizes()
+                                        .get_typed_unchecked(ctx, generator, &src_idx, None)
+                                }))
+                            },
+                        )?.map(BasicValueEnum::into_int_value).unwrap())
+                    },
+                )?;
+
+                ndarray_sliced_copyto_impl(
+                    generator,
+                    ctx,
+                    elem_ty,
+                    (ndarray, ndarray.data().base_ptr(ctx, generator)),
+                    (object, object.data().base_ptr(ctx, generator)),
+                    0,
+                    &[],
+                )?;
+
+                Ok(Some(ndarray.as_base_value()))
+            },
+            |_, _| {
+                Ok(Some(object.as_base_value()))
+            },
+        )?;
+
+        return Ok(NDArrayValue::from_ptr_val(
+            ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
+            llvm_usize,
+            None,
+        ))
+    }
+
+    // Remaining case: TList
+    assert!(ListValue::is_instance(object, llvm_usize).is_ok());
+    let object = ListValue::from_ptr_val(object, llvm_usize, None);
+
+    // The number of dimensions to prepend 1's to
+    let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
+    let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
+    let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
+
+    let ndarray = create_ndarray_dyn_shape(
+        generator,
+        ctx,
+        elem_ty,
+        &object,
+        |generator, ctx, object| {
+            let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
+            let ndmin_gt_ndims = ctx.builder
+                .build_int_compare(IntPredicate::UGT, ndmin, ndims, "")
+                .unwrap();
+
+            Ok(ctx.builder
+                .build_select(ndmin_gt_ndims, ndmin, ndims, "")
+                .map(BasicValueEnum::into_int_value)
+                .unwrap())
+        },
+        |generator, ctx, object, idx| {
+            Ok(gen_if_else_expr_callback(
+                generator,
+                ctx,
+                |_, ctx| {
+                    Ok(ctx.builder
+                        .build_int_compare(IntPredicate::ULT, idx, offset, "")
+                        .unwrap())
+                },
+                |_, _| {
+                    Ok(Some(llvm_usize.const_int(1, false)))
+                },
+                |generator, ctx| {
+                    let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| {
+                        ctx.ctx.struct_type(
+                            &[
+                                elem_ty.ptr_type(AddressSpace::default()).into(),
+                                llvm_usize.into(),
+                            ],
+                            false,
+                        )
+                    };
+
+                    let llvm_i8 = ctx.ctx.i8_type();
+                    let llvm_list_i8 = make_llvm_list(llvm_i8.into());
+                    let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default());
+
+                    // Cast list to { i8*, usize } since we only care about the size
+                    let lst = generator.gen_var_alloc(
+                        ctx,
+                        ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(),
+                        None,
+                    ).unwrap();
+                    ctx.builder.build_store(
+                        lst,
+                        ctx.builder.build_bitcast(
+                            object.as_base_value(),
+                            llvm_plist_i8,
+                            "",
+                        ).unwrap(),
+                    ).unwrap();
+
+                    let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap();
+                    gen_for_range_callback(
+                        generator,
+                        ctx,
+                        true,
+                        |_, _| Ok(llvm_usize.const_zero()),
+                        (|_, _| Ok(stop), false),
+                        |_, _| Ok(llvm_usize.const_int(1, false)),
+                        |generator, ctx, _| {
+                            let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
+                                .ptr_type(AddressSpace::default());
+
+                            let this_dim = ctx.builder
+                                .build_load(lst, "")
+                                .map(BasicValueEnum::into_pointer_value)
+                                .map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap())
+                                .map(BasicValueEnum::into_pointer_value)
+                                .unwrap();
+                            let this_dim = ListValue::from_ptr_val(
+                                this_dim,
+                                llvm_usize,
+                                None,
+                            );
+
+                            // TODO: Assert this_dim.sz != 0
+
+                            let next_dim = unsafe {
+                                this_dim.data()
+                                    .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
+                            }.into_pointer_value();
+                            ctx.builder.build_store(
+                                lst,
+                                ctx.builder.build_bitcast(
+                                    next_dim,
+                                    llvm_plist_i8,
+                                    "",
+                                ).unwrap(),
+                            ).unwrap();
+
+                            Ok(())
+                        },
+                    )?;
+
+                    let lst = ListValue::from_ptr_val(
+                        ctx.builder
+                            .build_load(lst, "")
+                            .map(BasicValueEnum::into_pointer_value)
+                            .unwrap(),
+                        llvm_usize,
+                        None,
+                    );
+
+                    Ok(Some(lst.load_size(ctx, None)))
+                },
+            )?.map(BasicValueEnum::into_int_value).unwrap())
+        },
+    )?;
+
+    ndarray_from_ndlist_impl(
+        generator,
+        ctx,
+        elem_ty,
+        (ndarray, ndarray.data().base_ptr(ctx, generator)),
+        object,
+        0,
+    )?;
+
+    Ok(ndarray)
+}
+
 /// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
 ///
 /// * `elem_ty` - The element type of the `NDArray`.
@@ -1450,6 +1896,73 @@ pub fn gen_ndarray_full<'ctx>(
     ).map(NDArrayValue::into)
 }
 
+/// Generates LLVM IR for `ndarray.array`.
+///
+/// NOTE(review): `copy` and `ndmin` are only looked up by keyword name here; when passed
+/// positionally their default values appear to be used instead - confirm against the caller.
+pub fn gen_ndarray_array<'ctx>(
+    context: &mut CodeGenContext<'ctx, '_>,
+    obj: &Option<(Type, ValueEnum<'ctx>)>,
+    fun: (&FunSignature, DefinitionId),
+    args: &[(Option<StrRef>, ValueEnum<'ctx>)],
+    generator: &mut dyn CodeGenerator,
+) -> Result<PointerValue<'ctx>, String> {
+    assert!(obj.is_none());
+    assert!(matches!(args.len(), 1..=3));
+
+    let obj_ty = fun.0.args[0].ty;
+    let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
+        TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
+            unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
+        }
+
+        TypeEnum::TList { ty } => {
+            let mut ty = *ty;
+            while let TypeEnum::TList { ty: elem_ty } = &*context.unifier.get_ty_immutable(ty) {
+                ty = *elem_ty;
+            }
+            ty
+        },
+
+        _ => obj_ty,
+    };
+    let obj_arg = args[0].1.clone()
+        .to_basic_value_enum(context, generator, obj_ty)?;
+
+    let copy_arg = if let Some(arg) =
+        args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) {
+        let copy_ty = fun.0.args[1].ty;
+        arg.1.clone().to_basic_value_enum(context, generator, copy_ty)?
+    } else {
+        context.gen_symbol_val(
+            generator,
+            fun.0.args[1].default_value.as_ref().unwrap(),
+            fun.0.args[1].ty,
+        )
+    };
+
+    let ndmin_arg = if let Some(arg) =
+        args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) {
+        let ndmin_ty = fun.0.args[2].ty;
+        arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
+    } else {
+        context.gen_symbol_val(
+            generator,
+            fun.0.args[2].default_value.as_ref().unwrap(),
+            fun.0.args[2].ty,
+        )
+    };
+
+    call_ndarray_array_impl(
+        generator,
+        context,
+        obj_elem_ty,
+        obj_arg,
+        copy_arg.into_int_value(),
+        ndmin_arg.into_int_value(),
+    ).map(NDArrayValue::into)
+}
+
 /// Generates LLVM IR for `ndarray.eye`.
 pub fn gen_ndarray_eye<'ctx>(
     context: &mut CodeGenContext<'ctx, '_>,
diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs
index c169c603..5e1c9868 100644
--- a/nac3core/src/toplevel/builtins.rs
+++ b/nac3core/src/toplevel/builtins.rs
@@ -809,15 +809,16 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
                         },
                     ],
                     ret: ndarray,
-                    vars: VarMap::default(),
+                    vars: VarMap::from([(tv.1, tv.0)]),
                 })),
-                var_id: Vec::default(),
+                var_id: vec![tv.1],
                 instance_to_symbol: HashMap::default(),
                 instance_to_stmt: HashMap::default(),
                 resolver: None,
                 codegen_callback: Some(Arc::new(GenCall::new(Box::new(
                     |ctx, obj, fun, args, generator| {
-                        todo!()
+                        gen_ndarray_array(ctx, &obj, fun, &args, generator)
+                            .map(|val| Some(val.as_basic_value_enum()))
                     },
                 )))),
                 loc: None,