From 140f8f8a08fbfeea4fa5ced45fad177e7d4289f3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Nov 2023 13:25:53 +0800 Subject: [PATCH] core: Implement most ndarray-creation functions --- nac3core/src/codegen/generator.rs | 12 + nac3core/src/codegen/irrt/irrt.c | 40 +- nac3core/src/codegen/irrt/mod.rs | 146 ++-- nac3core/src/codegen/mod.rs | 45 +- nac3core/src/codegen/stmt.rs | 33 +- nac3core/src/toplevel/builtins.rs | 129 +++- nac3core/src/toplevel/numpy.rs | 729 +++++++++++++++++- nac3core/src/typecheck/type_inferencer/mod.rs | 63 +- nac3standalone/demo/interpret_demo.py | 5 + nac3standalone/demo/src/ndarray.py | 33 + 10 files changed, 1130 insertions(+), 105 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 7a86c0bb6..c5c6aedba 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -92,6 +92,18 @@ pub trait CodeGenerator { gen_var(ctx, ty, name) } + /// Allocate memory for a variable and return a pointer pointing to it. + /// The default implementation places the allocations at the start of the function. + fn gen_array_var_alloc<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: BasicTypeEnum<'ctx>, + size: IntValue<'ctx>, + name: Option<&str>, + ) -> Result, String> { + gen_array_var(ctx, ty, size, name) + } + /// Return a pointer pointing to the target of the expression. fn gen_store_target<'ctx>( &mut self, diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 80e48aa2e..8b28bc1ad 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -199,27 +199,27 @@ double __nac3_j0(double x) { } uint32_t __nac3_ndarray_calc_size( - const int32_t *list_data, + const uint64_t *list_data, uint32_t list_len ) { uint32_t num_elems = 1; for (uint32_t i = 0; i < list_len; ++i) { - int32_t val = list_data[i]; + uint64_t val = list_data[i]; __builtin_assume(val >= 0); - num_elems *= (uint32_t) list_data[i]; + num_elems *= list_data[i]; } return num_elems; } uint64_t __nac3_ndarray_calc_size64( - const int32_t *list_data, + const uint64_t *list_data, uint64_t list_len ) { uint64_t num_elems = 1; for (uint64_t i = 0; i < list_len; ++i) { - int32_t val = list_data[i]; + uint64_t val = list_data[i]; __builtin_assume(val >= 0); - num_elems *= (uint64_t) list_data[i]; + num_elems *= list_data[i]; } return num_elems; } @@ -240,4 +240,32 @@ void __nac3_ndarray_init_dims64( for (uint64_t i = 0; i < shape_len; ++i) { ndarray_dims[i] = (uint64_t) shape_data[i]; } +} + +void __nac3_ndarray_calc_nd_indices( + uint32_t index, + const uint32_t* dims, + uint32_t num_dims, + uint32_t* idxs +) { + uint32_t stride = 1; + for (uint32_t dim = 0; dim < num_dims; dim++) { + uint32_t i = num_dims - dim - 1; + idxs[i] = (index / stride) % dims[i]; + stride *= dims[i]; + } +} + +void __nac3_ndarray_calc_nd_indices64( + uint64_t index, + const uint64_t* dims, + uint64_t num_dims, + uint64_t* idxs +) { + uint64_t stride = 1; + for (uint64_t dim = 0; dim < num_dims; dim++) { + uint64_t i = num_dims - dim - 1; + idxs[i] = (index / stride) % dims[i]; + stride *= dims[i]; + } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d6906c632..e2add43bd 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,6 @@ use crate::typecheck::typedef::Type; -use super::{CodeGenContext, CodeGenerator}; +use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator}; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -12,9 +12,6 @@ use inkwell::{ }; use nac3parser::ast::Expr; -#[cfg(debug_assertions)] -use inkwell::types::AnyTypeEnum; - #[must_use] pub fn load_irrt(ctx: &Context) -> Module { let bitcode_buf = MemoryBuffer::create_from_memory_range( @@ -550,62 +547,21 @@ pub fn call_j0<'ctx>( .into_float_value() } -/// Checks whether the pointer `value` refers to a `list` in LLVM. -fn assert_is_list(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_shape_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { - panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") - }; - assert_eq!(llvm_shape_ty.count_fields(), 2); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); - } - - value -} - -/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. -fn assert_is_ndarray(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_ndarray_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") - }; - - assert_eq!(llvm_ndarray_ty.count_fields(), 3); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); - let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { - unreachable!() - }; - let BasicTypeEnum::PointerType(dims) = ndarray_dims else { - panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") - }; - assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); - } - - value -} - /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the /// calculated total size. /// -/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM -/// representation of a `list`. +/// * `num_dims` - An [IntValue] containing the number of dimensions. +/// * `dims` - A [PointerValue] to an array containing the size of each dimensions. pub fn call_ndarray_calc_size<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - shape: PointerValue<'ctx>, + num_dims: IntValue<'ctx>, + dims: PointerValue<'ctx>, ) -> IntValue<'ctx> { - assert_is_list(shape); - - let llvm_i32 = ctx.ctx.i32_type(); + let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default()); let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { 32 => "__nac3_ndarray_calc_size", @@ -614,7 +570,7 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( }; let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[ - llvm_pi32.into(), + llvm_pi64.into(), llvm_usize.into(), ], false, @@ -624,30 +580,12 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); - let ( - shape_data, - shape_len, - ) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; - ctx.builder .build_call( ndarray_calc_size_fn, &[ - ctx.builder.build_load(shape_data, "").into(), - ctx.builder.build_load(shape_len, "").into(), + dims.into(), + num_dims.into(), ], "", ) @@ -721,4 +659,68 @@ pub fn call_ndarray_init_dims<'ctx, 'a>( ], "", ); +} + +pub fn call_ndarray_calc_nd_indices<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + index: IntValue<'ctx>, + ndarray: PointerValue<'ctx>, +) -> Result, String> { + assert_is_ndarray(ndarray); + + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_nd_indices_dn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_calc_nd_indices", + 64 => "__nac3_ndarray_calc_nd_indices64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None) + }); + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + let ndarray_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_pointer_value(); + + let indices = ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ); + + ctx.builder.build_call( + ndarray_calc_nd_indices_fn, + &[ + index.into(), + ndarray_dims.into(), + ndarray_num_dims.into(), + indices.into(), + ], + "", + ); + + Ok(indices) } \ No newline at end of file diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 21943d44d..b1836a075 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -34,6 +34,9 @@ use std::sync::{ }; use std::thread; +#[cfg(debug_assertions)] +use inkwell::types::AnyTypeEnum; + pub mod concrete_type; pub mod expr; mod generator; @@ -236,7 +239,7 @@ pub struct WorkerRegistry { static_value_store: Arc>, /// LLVM-related options for code generation. - llvm_options: CodeGenLLVMOptions, + pub llvm_options: CodeGenLLVMOptions, } impl WorkerRegistry { @@ -995,3 +998,43 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } + +/// Checks whether the pointer `value` refers to a `list` in LLVM. +fn assert_is_list(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_shape_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { + panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") + }; + assert_eq!(llvm_shape_ty.count_fields(), 2); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); + } + + value +} + +/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. +fn assert_is_ndarray(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") + }; + + assert_eq!(llvm_ndarray_ty.count_fields(), 3); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); + let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { + unreachable!() + }; + let BasicTypeEnum::PointerType(dims) = ndarray_dims else { + panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") + }; + assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); + } + + value +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 111fb6c64..1cf57b298 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -15,7 +15,7 @@ use crate::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, - types::BasicTypeEnum, + types::{BasicType, BasicTypeEnum}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, }; @@ -54,6 +54,37 @@ pub fn gen_var<'ctx>( Ok(ptr) } +/// See [CodeGenerator::gen_array_var_alloc]. +pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>( + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: T, + size: IntValue<'ctx>, + name: Option<&str>, +) -> Result, String> { + // Restore debug location + let di_loc = ctx.debug_info.0.create_debug_location( + ctx.ctx, + ctx.current_loc.row as u32, + ctx.current_loc.column as u32, + ctx.debug_info.2, + None, + ); + + // put the alloca in init block + let current = ctx.builder.get_insert_block().unwrap(); + + // position before the last branching instruction... + ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap()); + ctx.builder.set_current_debug_location(di_loc); + + let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or("")); + + ctx.builder.position_at_end(current); + ctx.builder.set_current_debug_location(di_loc); + + Ok(ptr) +} + /// See [`CodeGenerator::gen_store_target`]. pub fn gen_store_target<'ctx, G: CodeGenerator>( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index cb3f650f3..d2eb458ae 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -13,7 +13,13 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::numpy::gen_ndarray_empty, + toplevel::numpy::{ + gen_ndarray_empty, + gen_ndarray_eye, + gen_ndarray_full, + gen_ndarray_ones, + gen_ndarray_zeros, + }, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -22,6 +28,7 @@ use inkwell::{ FloatPredicate, IntPredicate }; +use crate::toplevel::numpy::gen_ndarray_identity; type BuiltinInfo = Vec<(Arc>, Option)>; @@ -279,10 +286,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; + let ndarray = { + let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0); + primitives.1.add_ty(ndarray_ty) + }; let ndarray_float = { let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); primitives.1.add_ty(ndarray_ty_enum) }; + let ndarray_float_2d = { + let value = match primitives.0.size_t { + 64 => SymbolValue::U64(2u64), + 32 => SymbolValue::U32(2u32), + _ => unreachable!(), + }; + let ndims = primitives.1.add_ty(TypeEnum::TLiteral { + values: vec![value], + loc: None, + }); + + primitives.1.add_ty(TypeEnum::TNDArray { + ty: float, + ndims, + }) + }; let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = primitives.1.get_fresh_var_with_range( &[int32, int64, float, boolean, uint32, uint64], @@ -869,6 +896,89 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .map(|val| Some(val.as_basic_value_enum())) }), ), + create_fn_by_codegen( + primitives, + &var_map, + "np_zeros", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_zeros(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "np_ones", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_ones(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + { + let tv = primitives.1.get_fresh_var(Some("T".into()), None).0; + + create_fn_by_codegen( + primitives, + &var_map, + "np_full", + ndarray, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape"), (tv, "fill_value")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_full(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ) + }, + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_eye".into(), + simple_name: "np_eye".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "N".into(), ty: int32, default_value: None }, + // TODO(Derppening): Default values current do not work? + FuncArg { + name: "M".into(), + ty: int32, + default_value: Some(SymbolValue::OptionNone) + }, + FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) }, + ], + ret: ndarray_float_2d, + vars: var_map.clone(), + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_eye(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + })), + create_fn_by_codegen( + primitives, + &var_map, + "np_identity", + ndarray_float_2d, + &[(int32, "n")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_identity(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), create_fn_by_codegen( primitives, &var_map, @@ -1364,7 +1474,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) } } - TypeEnum::TNDArray { .. } => todo!(), + TypeEnum::TNDArray { .. } => { + let llvm_i32 = ctx.ctx.i32_type(); + let i32_zero = llvm_i32.const_zero(); + + let len = ctx.build_gep_and_load( + arg.into_pointer_value(), + &[i32_zero, i32_zero], + None, + ).into_int_value(); + + if len.get_type().get_bit_width() != 32 { + Some(ctx.builder.build_int_truncate(len, llvm_i32, "len").into()) + } else { + Some(len.into()) + } + } _ => unreachable!(), } }) diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 9b91e826e..13bb8a537 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,14 +1,15 @@ -use inkwell::{ - IntPredicate, - types::BasicType, - values::PointerValue, -}; +use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; +use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ CodeGenContext, CodeGenerator, - irrt::{call_ndarray_calc_size, call_ndarray_init_dims}, + irrt::{ + call_ndarray_calc_nd_indices, + call_ndarray_calc_size, + call_ndarray_init_dims, + }, stmt::gen_for_callback }, symbol_resolver::ValueEnum, @@ -16,16 +17,201 @@ use crate::{ typecheck::typedef::{FunSignature, Type, TypeEnum}, }; -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. +/// Creates an `NDArray` instance from a constant shape. /// -/// * `elem_ty` - The element type of the NDArray. -/// * `var_name` - The variable name of the NDArray. -/// * `shape` - The `shape` parameter used to construct the NDArray. -fn call_ndarray_impl<'ctx, 'a>( +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue]. +fn create_ndarray_const_shape<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: ArrayValue<'ctx> +) -> Result, String> { + let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); + let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + + for i in 0..shape.get_type().len() { + let shape_dim = ctx.builder.build_extract_value( + shape, + i, + "", + ).unwrap(); + + let shape_dim_gez = ctx.builder.build_int_compare( + IntPredicate::SGE, + shape_dim.into_int_value(), + llvm_usize.const_zero(), + "" + ); + + ctx.make_assert( + generator, + shape_dim_gez, + "0:ValueError", + "negative dimensions not supported", + [None, None, None], + ctx.current_loc, + ); + } + + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; + + let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false); + + let ndarray_num_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "", + ) + }; + ctx.builder.build_store(ndarray_num_dims, num_dims); + + let ndarray_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "", + ) + }; + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_store( + ndarray_dims, + ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ), + ); + + for i in 0..shape.get_type().len() { + let ndarray_dim = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_pointer_value(); + let ndarray_dim = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray_dim, + &[llvm_i32.const_int(i as u64, true)], + "", + ) + }; + let shape_dim = ctx.builder.build_extract_value(shape, i, "") + .map(|val| val.into_int_value()) + .unwrap(); + + ctx.builder.build_store(ndarray_dim, shape_dim); + } + + let (ndarray_num_dims, ndarray_dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), + ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), + ); + + let ndarray_data = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + "", + ) + }; + ctx.builder.build_store( + ndarray_data, + ctx.builder.build_array_alloca( + llvm_ndarray_data_t, + ndarray_num_elems, + "" + ), + ); + + Ok(ndarray) +} + +fn ndarray_zero_value<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + unreachable!() + } +} + +fn ndarray_one_value<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + unreachable!() + } +} + +/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_empty_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - var_name: Option<&str>, shape: PointerValue<'ctx>, ) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); @@ -43,8 +229,8 @@ fn call_ndarray_impl<'ctx, 'a>( gen_for_callback( generator, ctx, - |_, ctx| { - let i = ctx.builder.build_alloca(llvm_usize, ""); + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; ctx.builder.build_store(i, llvm_usize.const_zero()); Ok(i) @@ -106,10 +292,11 @@ fn call_ndarray_impl<'ctx, 'a>( }, )?; - let ndarray = ctx.builder.build_alloca( - llvm_ndarray_t, - var_name.unwrap_or_default() - ); + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; let num_dims = ctx.build_gep_and_load( shape, @@ -151,7 +338,26 @@ fn call_ndarray_impl<'ctx, 'a>( call_ndarray_init_dims(generator, ctx, ndarray, shape); - let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape); + let (ndarray_num_dims, ndarray_dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), + ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), + ); let ndarray_data = unsafe { ctx.builder.build_in_bounds_gep( @@ -172,6 +378,342 @@ fn call_ndarray_impl<'ctx, 'a>( Ok(ndarray) } +/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as +/// its input. +/// +/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements +/// with the given value (as opposed to all elements within the array). +fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + value_fn: ValueFn, +) -> Result<(), String> + where + ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result, String>, +{ + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (num_dims, dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(num_dims, "").into_int_value(), + ctx.builder.build_load(dims, "").into_pointer_value(), + ); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "")) + }, + |generator, ctx, i_addr| { + let ndarray_data = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + None + ).into_pointer_value(); + + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let elem = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray_data, + &[i], + "" + ) + }; + + let value = value_fn(generator, ctx, i)?; + ctx.builder.build_store(elem, value); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + ) +} + +/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices +/// as its input +/// +/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements +/// with the given value (as opposed to all elements within the array). +fn ndarray_fill_indexed<'ctx, 'a, ValueFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + value_fn: ValueFn, +) -> Result<(), String> + where + ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result, String>, +{ + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, idx| { + let indices = call_ndarray_calc_nd_indices( + generator, + ctx, + idx, + ndarray, + )?; + + value_fn(generator, ctx, indices) + } + ) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_zeros_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, +) -> Result, String> { + let supported_types = [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ctx.primitives.float, + ctx.primitives.bool, + ctx.primitives.str, + ]; + assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); + + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = ndarray_zero_value(generator, ctx, elem_ty); + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_ones_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, +) -> Result, String> { + let supported_types = [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ctx.primitives.float, + ctx.primitives.bool, + ctx.primitives.str, + ]; + assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); + + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = ndarray_one_value(generator, ctx, elem_ty); + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_full_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, + fill_value: BasicValueEnum<'ctx>, +) -> Result, String> { + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = if fill_value.is_pointer_value() { + let llvm_void = ctx.ctx.void_type(); + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + + let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; + + let memcpy_fn_name = format!( + "llvm.memcpy.p0i8.p0i8.i{}", + generator.get_size_type(ctx.ctx).get_bit_width(), + ); + let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_pi8.into(), + llvm_pi8.into(), + llvm_usize.into(), + llvm_i1.into(), + ], + false, + ); + + ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None) + }); + + ctx.builder.build_call( + memcpy_fn, + &[ + copy.into(), + fill_value.into(), + fill_value.get_type().size_of().unwrap().into(), + llvm_i1.const_zero().into(), + ], + "", + ); + + copy.into() + } else if fill_value.is_int_value() || fill_value.is_float_value() { + fill_value.into() + } else { + unreachable!() + }; + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. +/// +/// * `elem_ty` - The element type of the NDArray. +fn call_ndarray_eye_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize_2 = llvm_usize.array_type(2); + + let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?; + + let shape = ctx.builder.build_load(shape_addr, "") + .into_array_value(); + + let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, ""); + let shape = ctx.builder + .build_insert_value(shape, nrows, 0, "") + .map(|val| val.into_array_value()) + .unwrap(); + + let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, ""); + let shape = ctx.builder + .build_insert_value(shape, ncols, 1, "") + .map(|val| val.into_array_value()) + .unwrap(); + + let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?; + + ndarray_fill_indexed( + generator, + ctx, + ndarray, + |generator, ctx, indices| { + let row = ctx.build_gep_and_load( + indices, + &[llvm_i32.const_zero()], + None, + ).into_int_value(); + let col = ctx.build_gep_and_load( + indices, + &[llvm_i32.const_int(1, true)], + None, + ).into_int_value(); + + let col_with_offset = ctx.builder.build_int_add( + col, + ctx.builder.build_int_z_extend_or_bit_cast(offset, llvm_usize, ""), + "" + ); + let is_on_diag = ctx.builder.build_int_compare( + IntPredicate::EQ, + row, + col_with_offset, + "" + ); + + let zero = ndarray_zero_value(generator, ctx, elem_ty); + let one = ndarray_one_value(generator, ctx, elem_ty); + + let value = ctx.builder.build_select(is_on_diag, one, zero, ""); + + Ok(value) + }, + )?; + + Ok(ndarray) +} + /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx, 'a>( context: &mut CodeGenContext<'ctx, 'a>, @@ -184,15 +726,158 @@ pub fn gen_ndarray_empty<'ctx, 'a>( assert_eq!(args.len(), 1); let shape_ty = fun.0.args[0].ty; - let shape_arg_name = args[0].0; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_impl( + call_ndarray_empty_impl( generator, context, context.primitives.float, - shape_arg_name.map(|name| name.to_string()).as_deref(), shape_arg.into_pointer_value(), ) +} + +/// Generates LLVM IR for `ndarray.zeros`. +pub fn gen_ndarray_zeros<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_zeros_impl( + generator, + context, + context.primitives.float, + shape_arg.into_pointer_value(), + ) +} + +/// Generates LLVM IR for `ndarray.ones`. +pub fn gen_ndarray_ones<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_ones_impl( + generator, + context, + context.primitives.float, + shape_arg.into_pointer_value(), + ) +} + +/// Generates LLVM IR for `ndarray.full`. +pub fn gen_ndarray_full<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 2); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + let fill_value_ty = fun.0.args[1].ty; + let fill_value_arg = args[1].1.clone() + .to_basic_value_enum(context, generator, fill_value_ty)?; + + call_ndarray_full_impl( + generator, + context, + fill_value_ty, + shape_arg.into_pointer_value(), + fill_value_arg, + ) +} + +/// Generates LLVM IR for `ndarray.eye`. +pub fn gen_ndarray_eye<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert!(matches!(args.len(), 1..=3)); + + let nrows_ty = fun.0.args[0].ty; + let nrows_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, nrows_ty)?; + + let ncols_ty = fun.0.args[1].ty; + let ncols_arg = args.iter() + .find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false)) + .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)) + .unwrap_or_else(|| { + args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty) + })?; + + let offset_ty = fun.0.args[2].ty; + let offset_arg = args.iter() + .find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false)) + .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty)) + .unwrap_or_else(|| { + Ok(context.gen_symbol_val( + generator, + fun.0.args[2].default_value.as_ref().unwrap(), + offset_ty + )) + })?; + + call_ndarray_eye_impl( + generator, + context, + context.primitives.float, + nrows_arg.into_int_value(), + ncols_arg.into_int_value(), + offset_arg.into_int_value(), + ) +} + +/// Generates LLVM IR for `ndarray.identity`. +pub fn gen_ndarray_identity<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let llvm_usize = generator.get_size_type(context.ctx); + + let n_ty = fun.0.args[0].ty; + let n_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, n_ty)?; + + call_ndarray_eye_impl( + generator, + context, + context.primitives.float, + n_arg.into_int_value(), + n_arg.into_int_value(), + llvm_usize.const_zero(), + ) } \ No newline at end of file diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index cfe000fa6..1462ca65d 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -898,9 +898,14 @@ impl<'a> Inferencer<'a> { if [ "np_ndarray".into(), "np_empty".into(), + "np_zeros".into(), + "np_ones".into(), ].contains(id) && args.len() == 1 { let ExprKind::List { elts, .. } = &args[0].node else { - return report_error("Expected List literal for first argument of np_ndarray", args[0].location) + return report_error( + format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), + args[0].location + ) }; let ndims = elts.len() as u64; @@ -941,6 +946,62 @@ impl<'a> Inferencer<'a> { })) } + // 2-argument ndarray n-dimensional creation functions + if id == &"np_full".into() && args.len() == 2 { + let ExprKind::List { elts, .. } = &args[0].node else { + return report_error( + format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), + args[0].location + ) + }; + + let ndims = elts.len() as u64; + + let arg0 = self.fold_expr(args.remove(0))?; + let arg1 = self.fold_expr(args.remove(0))?; + + let ty = arg1.custom.unwrap(); + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + + let ret = self.unifier.add_ty(TypeEnum::TNDArray { + ty, + ndims + }); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "shape".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + FuncArg { + name: "fill_value".into(), + ty: arg1.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: HashMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0, arg1], + keywords: vec![], + }, + })) + } + Ok(None) } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index abdeda971..03deff455 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -187,6 +187,11 @@ def patch(module): # NumPy NDArray Functions module.np_ndarray = np.ndarray module.np_empty = np.empty + module.np_zeros = np.zeros + module.np_ones = np.ones + module.np_full = np.full + module.np_eye = np.eye + module.np_identity = np.identity def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 1237d0692..1ab153b93 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -7,6 +7,12 @@ def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]): def consume_ndarray_2(n: ndarray[float, Literal[2]]): pass +def consume_ndarray_i32_1(n: ndarray[int32, 1]): + pass + +def consume_ndarray_2(n: ndarray[float, 2]): + pass + def test_ndarray_ctor(): n = np_ndarray([1]) consume_ndarray_1(n) @@ -15,8 +21,35 @@ def test_ndarray_empty(): n = np_empty([1]) consume_ndarray_1(n) +def test_ndarray_zeros(): + n = np_zeros([1]) + consume_ndarray_1(n) + +def test_ndarray_ones(): + n = np_ones([1]) + consume_ndarray_1(n) + +def test_ndarray_full(): + n_float = np_full([1], 2.0) + consume_ndarray_1(n_float) + n_i32 = np_full([1], 2) + consume_ndarray_i32_1(n_i32) + +def test_ndarray_eye(): + n = np_eye(2) + consume_ndarray_2(n) + +def test_ndarray_identity(): + n = np_identity(2) + consume_ndarray_2(n) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() + test_ndarray_zeros() + test_ndarray_ones() + test_ndarray_full() + test_ndarray_eye() + test_ndarray_identity() return 0