From 27fcf8926ef17d70b2cd470d18d8c8fc5ec21353 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 17 Nov 2023 17:30:27 +0800 Subject: [PATCH] core: Implement ndarray constructor and numpy.empty --- nac3core/src/codegen/irrt/irrt.c | 44 ++++ nac3core/src/codegen/irrt/mod.rs | 176 ++++++++++++++++ nac3core/src/codegen/stmt.rs | 76 ++++++- nac3core/src/toplevel/builtins.rs | 34 ++- nac3core/src/toplevel/mod.rs | 1 + nac3core/src/toplevel/numpy.rs | 198 ++++++++++++++++++ nac3core/src/typecheck/type_inferencer/mod.rs | 49 ++++- nac3standalone/demo/interpret_demo.py | 23 +- nac3standalone/demo/src/ndarray.py | 22 ++ 9 files changed, 619 insertions(+), 4 deletions(-) create mode 100644 nac3core/src/toplevel/numpy.rs create mode 100644 nac3standalone/demo/src/ndarray.py diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index d68b3446..80e48aa2 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -196,4 +196,48 @@ double __nac3_j0(double x) { } return j0(x); +} + +uint32_t __nac3_ndarray_calc_size( + const int32_t *list_data, + uint32_t list_len +) { + uint32_t num_elems = 1; + for (uint32_t i = 0; i < list_len; ++i) { + int32_t val = list_data[i]; + __builtin_assume(val >= 0); + num_elems *= (uint32_t) list_data[i]; + } + return num_elems; +} + +uint64_t __nac3_ndarray_calc_size64( + const int32_t *list_data, + uint64_t list_len +) { + uint64_t num_elems = 1; + for (uint64_t i = 0; i < list_len; ++i) { + int32_t val = list_data[i]; + __builtin_assume(val >= 0); + num_elems *= (uint64_t) list_data[i]; + } + return num_elems; +} + +void __nac3_ndarray_init_dims( + uint32_t *ndarray_dims, + const int32_t *shape_data, + uint32_t shape_len +) { + __builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t)); +} + +void __nac3_ndarray_init_dims64( + uint64_t *ndarray_dims, + const int32_t *shape_data, + uint64_t shape_len +) { + for (uint64_t i = 0; i < shape_len; ++i) { + ndarray_dims[i] = (uint64_t) shape_data[i]; + } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 0cd0ebdd..d6906c63 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -12,6 +12,9 @@ use inkwell::{ }; use nac3parser::ast::Expr; +#[cfg(debug_assertions)] +use inkwell::types::AnyTypeEnum; + #[must_use] pub fn load_irrt(ctx: &Context) -> Module { let bitcode_buf = MemoryBuffer::create_from_memory_range( @@ -546,3 +549,176 @@ pub fn call_j0<'ctx>( .unwrap_left() .into_float_value() } + +/// Checks whether the pointer `value` refers to a `list` in LLVM. +fn assert_is_list(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_shape_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { + panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") + }; + assert_eq!(llvm_shape_ty.count_fields(), 2); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); + } + + value +} + +/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. +fn assert_is_ndarray(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") + }; + + assert_eq!(llvm_ndarray_ty.count_fields(), 3); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); + let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { + unreachable!() + }; + let BasicTypeEnum::PointerType(dims) = ndarray_dims else { + panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") + }; + assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); + } + + value +} + +/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the +/// calculated total size. +/// +/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM +/// representation of a `list`. +pub fn call_ndarray_calc_size<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + shape: PointerValue<'ctx>, +) -> IntValue<'ctx> { + assert_is_list(shape); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + + let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_calc_size", + 64 => "__nac3_ndarray_calc_size64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_size_fn_t = llvm_usize.fn_type( + &[ + llvm_pi32.into(), + llvm_usize.into(), + ], + false, + ); + let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name) + .unwrap_or_else(|| { + ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + }); + + let ( + shape_data, + shape_len, + ) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + + ctx.builder + .build_call( + ndarray_calc_size_fn, + &[ + ctx.builder.build_load(shape_data, "").into(), + ctx.builder.build_load(shape_len, "").into(), + ], + "", + ) + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} + +/// Generates a call to `__nac3_ndarray_init_dims`. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM +/// representation of a `list`. +pub fn call_ndarray_init_dims<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + shape: PointerValue<'ctx>, +) { + assert_is_ndarray(ndarray); + assert_is_list(shape); + + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_init_dims", + 64 => "__nac3_ndarray_init_dims64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_pusize.into(), + llvm_pi32.into(), + llvm_usize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None) + }); + + let ndarray_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ); + let shape_data = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None + ); + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_call( + ndarray_init_dims_fn, + &[ + ndarray_dims.into(), + shape_data.into(), + ndarray_num_dims.into(), + ], + "", + ); +} \ No newline at end of file diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index e53ea325..111fb6c6 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -16,7 +16,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, + values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, }; use nac3parser::ast::{ @@ -405,6 +405,80 @@ pub fn gen_for( Ok(()) } +/// Generates a C-style `for` construct using lambdas, similar to the following C code: +/// +/// ```c +/// for (x... = init(); cond(x...); update(x...)) { +/// body(x...); +/// } +/// ``` +/// +/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The +/// return value is a [Clone] value which will be passed to the other lambdas. +/// * `cond` - A lambda containing IR statements checking whether the loop should continue +/// executing. The result value must be an `i1` indicating if the loop should continue. +/// * `body` - A lambda containing IR statements within the loop body. +/// * `update` - A lambda containing IR statements updating loop variables. +pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: InitFn, + cond: CondFn, + body: BodyFn, + update: UpdateFn, +) -> Result<(), String> + where + I: Clone, + InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result, + CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, + BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, + UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, +{ + let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap(); + let init_bb = ctx.ctx.append_basic_block(current, "for.init"); + // The BB containing the loop condition check + let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); + let body_bb = ctx.ctx.append_basic_block(current, "for.body"); + // The BB containing the increment expression + let update_bb = ctx.ctx.append_basic_block(current, "for.update"); + let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); + + // store loop bb information and restore it later + let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); + + ctx.builder.build_unconditional_branch(init_bb); + + let loop_var = { + ctx.builder.position_at_end(init_bb); + let result = init(generator, ctx)?; + ctx.builder.build_unconditional_branch(cond_bb); + + result + }; + + ctx.builder.position_at_end(cond_bb); + let cond = cond(generator, ctx, loop_var.clone())?; + assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width()); + ctx.builder.build_conditional_branch( + cond, + body_bb, + cont_bb + ); + + ctx.builder.position_at_end(body_bb); + body(generator, ctx, loop_var.clone())?; + ctx.builder.build_unconditional_branch(update_bb); + + ctx.builder.position_at_end(update_bb); + update(generator, ctx, loop_var)?; + ctx.builder.build_unconditional_branch(cond_bb); + + ctx.builder.position_at_end(cont_bb); + ctx.loop_target = loop_bb; + + Ok(()) +} + /// See [`CodeGenerator::gen_while`]. pub fn gen_while( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 00da280f..cb3f650f 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -13,11 +13,12 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, + toplevel::numpy::gen_ndarray_empty, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{BasicType, BasicMetadataTypeEnum}, - values::BasicMetadataValueEnum, + values::{BasicValue, BasicMetadataValueEnum}, FloatPredicate, IntPredicate }; @@ -278,6 +279,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; + let ndarray_float = { + let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); + primitives.1.add_ty(ndarray_ty_enum) + }; + let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = primitives.1.get_fresh_var_with_range( &[int32, int64, float, boolean, uint32, uint64], Some("N".into()), @@ -837,6 +843,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + create_fn_by_codegen( + primitives, + &var_map, + "np_ndarray", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_empty(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "np_empty", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_empty(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), create_fn_by_codegen( primitives, &var_map, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c62c5c0a..c204819a 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize); pub mod builtins; pub mod composer; pub mod helper; +pub mod numpy; pub mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs new file mode 100644 index 00000000..9b91e826 --- /dev/null +++ b/nac3core/src/toplevel/numpy.rs @@ -0,0 +1,198 @@ +use inkwell::{ + IntPredicate, + types::BasicType, + values::PointerValue, +}; +use nac3parser::ast::StrRef; +use crate::{ + codegen::{ + CodeGenContext, + CodeGenerator, + irrt::{call_ndarray_calc_size, call_ndarray_init_dims}, + stmt::gen_for_callback + }, + symbol_resolver::ValueEnum, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type, TypeEnum}, +}; + +/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `var_name` - The variable name of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + var_name: Option<&str>, + shape: PointerValue<'ctx>, +) -> Result, String> { + let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); + let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + + // Assert that all dimensions are non-negative + gen_for_callback( + generator, + ctx, + |_, ctx| { + let i = ctx.builder.build_alloca(llvm_usize, ""); + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let shape_len = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_int_value(); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, "")) + }, + |generator, ctx, i_addr| { + let shape_elems = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None + ).into_pointer_value(); + + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let shape_dim = ctx.build_gep_and_load( + shape_elems, + &[i], + None + ).into_int_value(); + + let shape_dim_gez = ctx.builder.build_int_compare( + IntPredicate::SGE, + shape_dim, + llvm_i32.const_zero(), + "" + ); + + ctx.make_assert( + generator, + shape_dim_gez, + "0:ValueError", + "negative dimensions not supported", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + )?; + + let ndarray = ctx.builder.build_alloca( + llvm_ndarray_t, + var_name.unwrap_or_default() + ); + + let num_dims = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None + ).into_int_value(); + + let ndarray_num_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "", + ) + }; + ctx.builder.build_store(ndarray_num_dims, num_dims); + + let ndarray_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "", + ) + }; + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_store( + ndarray_dims, + ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ), + ); + + call_ndarray_init_dims(generator, ctx, ndarray, shape); + + let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape); + + let ndarray_data = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + "", + ) + }; + ctx.builder.build_store( + ndarray_data, + ctx.builder.build_array_alloca( + llvm_ndarray_data_t, + ndarray_num_elems, + "", + ), + ); + + Ok(ndarray) +} + +/// Generates LLVM IR for `ndarray.empty`. +pub fn gen_ndarray_empty<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg_name = args[0].0; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_impl( + generator, + context, + context.primitives.float, + shape_arg_name.map(|name| name.to_string()).as_deref(), + shape_arg.into_pointer_value(), + ) +} \ No newline at end of file diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index f9d33966..cfe000fa 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; -use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; +use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext}; use itertools::izip; use nac3parser::ast::{ self, @@ -894,6 +894,53 @@ impl<'a> Inferencer<'a> { } } + // 1-argument ndarray n-dimensional creation functions + if [ + "np_ndarray".into(), + "np_empty".into(), + ].contains(id) && args.len() == 1 { + let ExprKind::List { elts, .. } = &args[0].node else { + return report_error("Expected List literal for first argument of np_ndarray", args[0].location) + }; + + let ndims = elts.len() as u64; + + let arg0 = self.fold_expr(args.remove(0))?; + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + let ret = self.unifier.add_ty(TypeEnum::TNDArray { + ty: self.primitives.float, + ndims + }); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "shape".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: HashMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + Ok(None) } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 830e86d9..abdeda97 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -5,11 +5,12 @@ import importlib.util import importlib.machinery import math import numpy as np +import numpy.typing as npt import pathlib from numpy import int32, int64, uint32, uint64 from scipy import special -from typing import TypeVar, Generic, Literal +from typing import TypeVar, Generic, Literal, Union T = TypeVar('T') class Option(Generic[T]): @@ -50,6 +51,13 @@ class _ConstGenericMarker: def ConstGeneric(name, constraint): return TypeVar(name, _ConstGenericMarker, constraint) +N = TypeVar("N", bound=np.uint64) +class _NDArrayDummy(Generic[T, N]): + pass + +# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic +NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]] + def round_away_zero(x): if x >= 0.0: return math.floor(x + 0.5) @@ -124,6 +132,16 @@ def patch(module): module.ceil64 = math.ceil module.np_ceil = np.ceil + # NumPy ndarray functions + module.ndarray = NDArray + module.np_ndarray = np.ndarray + module.np_empty = np.empty + module.np_zeros = np.zeros + module.np_ones = np.ones + module.np_full = np.full + module.np_eye = np.eye + module.np_identity = np.identity + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf @@ -166,6 +184,9 @@ def patch(module): module.sp_spec_j0 = special.j0 module.sp_spec_j1 = special.j1 + # NumPy NDArray Functions + module.np_ndarray = np.ndarray + module.np_empty = np.empty def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py new file mode 100644 index 00000000..1237d069 --- /dev/null +++ b/nac3standalone/demo/src/ndarray.py @@ -0,0 +1,22 @@ +def consume_ndarray_1(n: ndarray[float, Literal[1]]): + pass + +def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]): + pass + +def consume_ndarray_2(n: ndarray[float, Literal[2]]): + pass + +def test_ndarray_ctor(): + n = np_ndarray([1]) + consume_ndarray_1(n) + +def test_ndarray_empty(): + n = np_empty([1]) + consume_ndarray_1(n) + +def run() -> int32: + test_ndarray_ctor() + test_ndarray_empty() + + return 0