From 49a7469b4a80d534cbf167b7d2b8d616e8d7c649 Mon Sep 17 00:00:00 2001 From: ram Date: Mon, 30 Dec 2024 13:02:09 +0800 Subject: [PATCH 01/46] use memcmp for string comparison Co-authored-by: ram Co-committed-by: ram --- nac3core/irrt/irrt.cpp | 2 + nac3core/irrt/irrt/string.hpp | 23 ++++++ nac3core/src/codegen/expr.rs | 108 ++++++---------------------- nac3core/src/codegen/irrt/mod.rs | 2 + nac3core/src/codegen/irrt/string.rs | 48 +++++++++++++ 5 files changed, 95 insertions(+), 88 deletions(-) create mode 100644 nac3core/irrt/irrt/string.hpp create mode 100644 nac3core/src/codegen/irrt/string.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 8447fc5a..722ed32d 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -4,7 +4,9 @@ #include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/string.hpp" diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 00000000..f695dcdc --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "irrt/int_types.hpp" + +namespace { +template +SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + if (len1 != len2){ + return 0; + } + return (__builtin_memcmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; +} +} // namespace + +extern "C" { +uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} + +uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0118ca43..c616449f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -24,7 +24,7 @@ use super::{ irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_int_umin, call_memcpy_generic, + call_memcpy_generic, }, macros::codegen_unreachable, need_sret, numpy, @@ -2045,111 +2045,43 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let lhs = lhs.into_struct_value(); let rhs = rhs.into_struct_value(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(prhs, rhs).unwrap(); + let lhs_ptr = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); + + let rhs_ptr = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - - let len = call_int_umin(ctx, lhs_len, rhs_len, None); - - let current_bb = ctx.builder.get_insert_block().unwrap(); - let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); - ctx.builder.position_at_end(current_bb); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let lhs_char = { - let plhs_data = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - plhs_data, - &[i], - None - ).into_int_value() - }; - let rhs_char = { - let prhs_data = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - prhs_data, - &[i], - None - ).into_int_value() - }; - - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) - }, - |_, ctx| { - let bb = ctx.builder.get_insert_block().unwrap(); - cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let bb = ctx.builder.get_insert_block().unwrap(); - let is_len_eq = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_len, - rhs_len, - "", - ).unwrap(); - cmp_phi.add_incoming(&[(&is_len_eq, bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = cmp_phi.as_basic_value().into_int_value(); - - // Invert the final value if __ne__ + let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { - cmp_phi + result } } else if [left_ty, right_ty] .iter() diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 824921cd..21a16bdb 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -15,12 +15,14 @@ pub use list::*; pub use math::*; pub use range::*; pub use slice::*; +pub use string::*; mod list; mod math; pub mod ndarray; mod range; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 00000000..fb0f27b9 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,48 @@ +use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; +use itertools::Either; + +use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator}; + +/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. +pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() { + 32 => ("nac3_str_eq", ctx.ctx.i32_type()), + 64 => ("nac3_str_eq64", ctx.ctx.i64_type()), + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + + let func = ctx.module.get_function(func_name).unwrap_or_else(|| { + ctx.module.add_function( + func_name, + return_type.fn_type( + &[ + str1_ptr.get_type().into(), + str1_len.get_type().into(), + str2_ptr.get_type().into(), + str2_len.get_type().into(), + ], + false, + ), + None, + ) + }); + let result = ctx + .builder + .build_call( + func, + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + "str_eq_call", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + generator.bool_to_i1(ctx, result) +} From 456aefa6ee1e71386e1dacc2a7e569a5ebc70f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Mon, 30 Dec 2024 13:03:31 +0800 Subject: [PATCH 02/46] clean up duplicate include --- nac3core/irrt/irrt.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 722ed32d..0d069869 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,3 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" -#include "irrt/string.hpp" From fbf0053c248a58421829b628b7c0f62be62a77f1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 30 Dec 2024 14:04:42 +0800 Subject: [PATCH 03/46] [core] irrt/string: Minor cleanup - Refactor __nac3_str_eq to always return bool - Use `get_usize_dependent_function_name` to get IRRT func name --- nac3core/irrt/irrt/string.hpp | 8 ++++---- nac3core/src/codegen/irrt/string.rs | 24 +++++++++++------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index f695dcdc..db3ad7fd 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -4,20 +4,20 @@ namespace { template -SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { +bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { if (len1 != len2){ return 0; } - return (__builtin_memcmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; + return __builtin_memcmp(str1, str2, static_cast(len1)) == 0; } } // namespace extern "C" { -uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { +bool nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } -uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { +bool nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index fb0f27b9..6ee40e45 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,7 +1,8 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; -use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator}; +use super::get_usize_dependent_function_name; +use crate::codegen::{CodeGenContext, CodeGenerator}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( @@ -12,16 +13,14 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( str2_ptr: PointerValue<'ctx>, str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { - let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() { - 32 => ("nac3_str_eq", ctx.ctx.i32_type()), - 64 => ("nac3_str_eq64", ctx.ctx.i64_type()), - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let llvm_i1 = ctx.ctx.bool_type(); - let func = ctx.module.get_function(func_name).unwrap_or_else(|| { + let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq"); + + let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { ctx.module.add_function( - func_name, - return_type.fn_type( + &func_name, + llvm_i1.fn_type( &[ str1_ptr.get_type().into(), str1_len.get_type().into(), @@ -33,8 +32,8 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( None, ) }); - let result = ctx - .builder + + ctx.builder .build_call( func, &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], @@ -43,6 +42,5 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) - .unwrap(); - generator.bool_to_i1(ctx, result) + .unwrap() } From 0e5940c49d0b71b4bd19e7e7b7fdb6a705f46dbe Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:43:24 +0800 Subject: [PATCH 04/46] [meta] Refactor itertools::{chain,enumerate,repeat_n} with std equiv --- nac3core/src/codegen/expr.rs | 4 ++-- nac3core/src/symbol_resolver.rs | 12 ++++++------ nac3core/src/toplevel/composer.rs | 2 +- nac3core/src/typecheck/typedef/mod.rs | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c616449f..606e7421 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -11,7 +11,7 @@ use inkwell::{ values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::{chain, izip, Either, Itertools}; +use itertools::{izip, Either, Itertools}; use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, @@ -1965,7 +1965,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } } - let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) + let cmp_val = izip!(once(&left).chain(comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; let (right_ty, rhs) = rhs; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index bab823c1..2378dd62 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -6,7 +6,7 @@ use std::{ }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; -use itertools::{chain, izip, Itertools}; +use itertools::{izip, Itertools}; use parking_lot::RwLock; use nac3parser::ast::{Constant, Expr, Location, StrRef}; @@ -452,11 +452,11 @@ pub fn parse_type_annotation( type_vars.len() )])); } - let fields = chain( - fields.iter().map(|(k, v, m)| (*k, (*v, *m))), - methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(); + let fields = fields + .iter() + .map(|(k, v, m)| (*k, (*v, *m))) + .chain(methods.iter().map(|(k, v, _)| (*k, (*v, false)))) + .collect(); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() })) } else { Err(HashSet::from([format!("Cannot use function name as type at {loc}")])) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 6040ced1..bd9a9214 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1052,7 +1052,7 @@ impl TopLevelComposer { } let mut result = Vec::new(); let no_defaults = args.args.len() - args.defaults.len() - 1; - for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) { + for (idx, x) in args.args.iter().skip(1).enumerate() { let type_ann = { let Some(annotation_expr) = x.node.annotation.as_ref() else {return Err(HashSet::from([format!("type annotation needed for `{}` (at {})", x.node.arg, x.location)]));}; parse_ast_to_type_annotation_kinds( diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 49cea04c..e190c4c4 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -3,13 +3,13 @@ use std::{ cell::RefCell, collections::{HashMap, HashSet}, fmt::{self, Display}, - iter::{repeat, zip}, + iter::{repeat, repeat_n, zip}, rc::Rc, sync::{Arc, Mutex}, }; use indexmap::IndexMap; -use itertools::{repeat_n, Itertools}; +use itertools::Itertools; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; From 35e3042435813ac07a129482cb7a382b9c840744 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 13:58:02 +0800 Subject: [PATCH 05/46] [core] Refactor/Remove redundant and unused constructs - Use ProxyValue.name where necessary - Remove NDArrayValue::ptr_to_{shape,strides} - Remove functions made obsolete by ndstrides - Remove use statement for ndarray::views as it only contain an impl block. - Remove class_names field in Resolvers of test sources --- nac3core/src/codegen/expr.rs | 324 +----------------- nac3core/src/codegen/llvm_intrinsics.rs | 33 +- nac3core/src/codegen/numpy.rs | 23 -- nac3core/src/codegen/test.rs | 15 +- nac3core/src/codegen/values/array.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 13 - nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/toplevel/builtins.rs | 146 +------- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/toplevel/test.rs | 18 +- .../src/typecheck/type_inferencer/test.rs | 4 - 15 files changed, 23 insertions(+), 573 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 606e7421..c3dfc80e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,19 +34,14 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::{NDArrayValue, RustNDIndex}, - ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, - TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::{extract_ndims, PrimDef}, - numpy::unpack_ndarray_var_tys, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -2444,319 +2439,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims_ty: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims_ty) else { - codegen_unreachable!(ctx) - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? - else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? { - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - let num_dims = extract_ndims(&ctx.unifier, ndims_ty) - 1; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let ndarray = - NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t, Some(num_dims)) - .construct_uninitialized(generator, ctx, None); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.shape().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.shape().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = ndarray::call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - unsafe { ndarray.create_data(generator, ctx) }; - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index 895339d0..9c360b6a 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -1,7 +1,6 @@ use inkwell::{ - context::Context, intrinsics::Intrinsic, - types::{AnyTypeEnum::IntType, FloatType}, + types::AnyTypeEnum::IntType, values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}, AddressSpace, }; @@ -9,34 +8,6 @@ use itertools::Either; use super::CodeGenContext; -/// Returns the string representation for the floating-point type `ft` when used in intrinsic -/// functions. -fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { - // Standard LLVM floating-point types - if ft == ctx.f16_type() { - return "f16"; - } - if ft == ctx.f32_type() { - return "f32"; - } - if ft == ctx.f64_type() { - return "f64"; - } - if ft == ctx.f128_type() { - return "f128"; - } - - // Non-standard floating-point types - if ft == ctx.x86_f80_type() { - return "f80"; - } - if ft == ctx.ppc_f128_type() { - return "ppcf128"; - } - - unreachable!() -} - /// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) /// intrinsic. pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { @@ -54,7 +25,7 @@ pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap(); } -/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) +/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic) /// intrinsic. pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { const FN_NAME: &str = "llvm.va_end"; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index cd113aae..513f8f51 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -604,29 +604,6 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( } } -/// Returns the number of dimensions for an array-like object as an [`IntValue`]. -fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (ty, value): (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match value { - BasicValueEnum::PointerValue(v) - if NDArrayValue::is_representable(v, llvm_usize).is_ok() => - { - NDArrayType::from_unifier_type(generator, ctx, ty).map_value(v, None).load_ndims(ctx) - } - - BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { - llvm_ndlist_get_ndims(generator, ctx, v.get_type()) - } - - _ => llvm_usize.const_zero(), - } -} - /// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 81c5836c..97bd3f09 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -36,7 +36,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: RwLock>, - class_names: HashMap, } impl Resolver { @@ -104,11 +103,9 @@ fn test_primitives() { let top_level = Arc::new(composer.make_top_level_context()); unifier.top_level = Some(top_level.clone()); - let resolver = Arc::new(Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }) as Arc; + let resolver = + Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) + as Arc; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let signature = FunSignature { @@ -298,11 +295,7 @@ fn test_simple_call() { loc: None, }))); - let resolver = Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }; + let resolver = Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }; resolver.add_id_def("foo".into(), DefinitionId(foo_id)); let resolver = Arc::new(resolver) as Arc; diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 78975f06..e6ebb258 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -389,7 +389,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + let var_name = name.or(self.2).map(|v| format!("{v}.addr")).unwrap_or_default(); unsafe { ctx.builder diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 12fd8634..eef74407 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -19,7 +19,6 @@ use crate::codegen::{ pub use contiguous::*; pub use indexing::*; pub use nditer::*; -pub use view::*; mod contiguous; mod indexing; @@ -113,12 +112,6 @@ impl<'ctx> NDArrayValue<'ctx> { self.get_type().get_fields(ctx.ctx).shape } - /// Returns the double-indirection pointer to the `shape` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name) - } - /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); @@ -147,12 +140,6 @@ impl<'ctx> NDArrayValue<'ctx> { self.get_type().get_fields(ctx.ctx).strides } - /// Returns the double-indirection pointer to the `strides` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.strides_field(ctx).ptr_by_gep(ctx, self.value, self.name) - } - /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 45a82b38..218afc19 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -78,7 +78,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element(ctx).get(ctx, self.as_base_value(), None); + let p = self.element(ctx).get(ctx, self.as_base_value(), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -98,7 +98,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth(ctx).get(ctx, self.as_base_value(), None) + self.nth(ctx).get(ctx, self.as_base_value(), self.name) } /// Get the indices of the current element. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e382222c..36bd85d1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,13 +1,7 @@ use std::iter::once; use indexmap::IndexMap; -use inkwell::{ - attributes::{Attribute, AttributeLoc}, - types::{BasicMetadataTypeEnum, BasicType}, - values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, - IntPredicate, -}; -use itertools::Either; +use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ @@ -148,144 +142,6 @@ fn create_fn_by_codegen( } } -/// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic. -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function. -fn create_fn_by_intrinsic( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - intrinsic_fn: &'static str, -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - - ctx.module.add_function(intrinsic_fn, fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, args_val.as_slice(), name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - -/// Creates a unary NumPy [`TopLevelDef`] function using an extern function (e.g. from `libc` or -/// `libm`). -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation. -/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is -/// already implied by the C ABI. -fn create_fn_by_extern( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - extern_fn: &'static str, - attrs: &'static [&str], -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - let func = ctx.module.add_function(extern_fn, fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), - ); - - for attr in attrs { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - - func - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &args_val, name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo { BuiltinBuilder::new(unifier, primitives) .build_all_builtins() diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 1b0c9b80..93f2096b 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2621337c..d3301d00 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d0769305..911426b9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 5ebdf86c..d60daf83 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a9c4dd6..517f6846 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 37a3dede..1f33a4ba 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -15,14 +15,13 @@ use crate::{ symbol_resolver::{SymbolResolver, ValueEnum}, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{into_var_map, Type, Unifier}, + typedef::{Type, Unifier}, }, }; struct ResolverInternal { id_to_type: Mutex>, id_to_def: Mutex>, - class_names: Mutex>, } impl ResolverInternal { @@ -179,11 +178,8 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { let mut composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; - let internal_resolver = Arc::new(ResolverInternal { - id_to_def: Mutex::default(), - id_to_type: Mutex::default(), - class_names: Mutex::default(), - }); + let internal_resolver = + Arc::new(ResolverInternal { id_to_def: Mutex::default(), id_to_type: Mutex::default() }); let resolver = Arc::new(Resolver(internal_resolver.clone())) as Arc; @@ -784,13 +780,6 @@ fn make_internal_resolver_with_tvar( unifier: &mut Unifier, print: bool, ) -> Arc { - let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None); - let list = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimDef::List.id(), - fields: HashMap::new(), - params: into_var_map([list_elem_tvar]), - }); - let res: Arc = ResolverInternal { id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])), id_to_type: tvars @@ -806,7 +795,6 @@ fn make_internal_resolver_with_tvar( }) .collect::>() .into(), - class_names: Mutex::new(HashMap::from([("list".into(), list)])), } .into(); if print { diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index e56cb283..a6583536 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -18,7 +18,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: HashMap, - class_names: HashMap, } impl SymbolResolver for Resolver { @@ -198,7 +197,6 @@ impl TestEnvironment { let resolver = Arc::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: HashMap::default(), - class_names: HashMap::default(), }) as Arc; TestEnvironment { @@ -454,7 +452,6 @@ impl TestEnvironment { vars: IndexMap::default(), })), ); - let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into(); let id_to_name = [ "int32".into(), @@ -492,7 +489,6 @@ impl TestEnvironment { ("Bar2".into(), DefinitionId(defs + 3)), ] .into(), - class_names, }) as Arc; TestEnvironment { From 318371a5093f04afe9e029fa4b2cf2d60accea06 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 20 Dec 2024 13:19:40 +0800 Subject: [PATCH 06/46] [core] irrt: Minor cleanup --- nac3core/irrt/irrt/ndarray.hpp | 2 +- nac3core/irrt/irrt/ndarray/basic.hpp | 6 ++---- nac3core/irrt/irrt/ndarray/indexing.hpp | 9 ++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 9d305aa4..72ca0b9e 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -148,4 +148,4 @@ void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, NDIndexInt* out_idx) { __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); } -} // namespace \ No newline at end of file +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp index 05ee30fc..62c92ae2 100644 --- a/nac3core/irrt/irrt/ndarray/basic.hpp +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -6,8 +6,7 @@ #include "irrt/ndarray/def.hpp" namespace { -namespace ndarray { -namespace basic { +namespace ndarray::basic { /** * @brief Assert that `shape` does not contain negative dimensions. * @@ -247,8 +246,7 @@ void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); } } -} // namespace basic -} // namespace ndarray +} // namespace ndarray::basic } // namespace extern "C" { diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index 9e9e7b6e..76e78473 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -65,8 +65,7 @@ struct NDIndex { } // namespace namespace { -namespace ndarray { -namespace indexing { +namespace ndarray::indexing { /** * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) * @@ -162,7 +161,8 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ Range range = slice->indices_checked(src_ndarray->shape[src_axis]); - dst_ndarray->data = static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->data = + static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; dst_ndarray->shape[dst_axis] = (SizeT)range.len(); @@ -197,8 +197,7 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); } -} // namespace indexing -} // namespace ndarray +} // namespace ndarray::indexing } // namespace extern "C" { From 19122e29050315740302b3a2b7146e1aa3968ce0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 13 Dec 2024 16:35:34 +0800 Subject: [PATCH 07/46] [core] codegen: Rename classes/functions for consistency - ContiguousNDArrayFields -> ContiguousNDArrayStructFields - ndarray/nditer: Add _field suffix to field accessors --- nac3core/src/codegen/types/ndarray/contiguous.rs | 14 +++++++------- nac3core/src/codegen/values/ndarray/nditer.rs | 11 +++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 4401cb62..317539c0 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -31,7 +31,7 @@ pub struct ContiguousNDArrayType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct ContiguousNDArrayFields<'ctx> { +pub struct ContiguousNDArrayStructFields<'ctx> { #[value_type(usize)] pub ndims: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize.ptr_type(AddressSpace::default()))] @@ -40,12 +40,12 @@ pub struct ContiguousNDArrayFields<'ctx> { pub data: StructField<'ctx, PointerValue<'ctx>>, } -impl<'ctx> ContiguousNDArrayFields<'ctx> { +impl<'ctx> ContiguousNDArrayStructFields<'ctx> { #[must_use] pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let mut counter = FieldIndexCounter::default(); - ContiguousNDArrayFields { + ContiguousNDArrayStructFields { ndims: StructField::create(&mut counter, "ndims", llvm_usize), shape: StructField::create( &mut counter, @@ -72,7 +72,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { )); }; - let fields = ContiguousNDArrayFields::new(ctx, llvm_usize); + let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); check_struct_type_matches_fields( fields, @@ -93,14 +93,14 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { fn fields( item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, - ) -> ContiguousNDArrayFields<'ctx> { - ContiguousNDArrayFields::new_typed(item, llvm_usize) + ) -> ContiguousNDArrayStructFields<'ctx> { + ContiguousNDArrayStructFields::new_typed(item, llvm_usize) } /// See [`NDArrayType::fields`]. // TODO: Move this into e.g. StructProxyType #[must_use] - pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> { + pub fn get_fields(&self) -> ContiguousNDArrayStructFields<'ctx> { Self::fields(self.item, self.llvm_usize) } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 218afc19..4b31f274 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -69,7 +69,10 @@ impl<'ctx> NDIterValue<'ctx> { irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); } - fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + fn element_field( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> StructField<'ctx, PointerValue<'ctx>> { self.get_type().get_fields(ctx.ctx).element } @@ -78,7 +81,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element(ctx).get(ctx, self.as_base_value(), self.name); + let p = self.element_field(ctx).get(ctx, self.as_base_value(), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -91,14 +94,14 @@ impl<'ctx> NDIterValue<'ctx> { ctx.builder.build_load(p, "value").unwrap() } - fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + fn nth_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { self.get_type().get_fields(ctx.ctx).nth } /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth(ctx).get(ctx, self.as_base_value(), self.name) + self.nth_field(ctx).get(ctx, self.as_base_value(), self.name) } /// Get the indices of the current element. From dc413dfa4331ff68664ab10d45828d124b166bb7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 18:03:03 +0800 Subject: [PATCH 08/46] [core] codegen: Refactor TypedArrayLikeAdapter to use fn Allows for greater flexibility when TypedArrayLikeAdapter is used with custom value types. --- nac3core/src/codegen/irrt/ndarray/mod.rs | 20 ++- nac3core/src/codegen/numpy.rs | 2 +- nac3core/src/codegen/values/array.rs | 128 +++++++++--------- nac3core/src/codegen/values/list.rs | 4 +- nac3core/src/codegen/values/ndarray/mod.rs | 54 +++++--- nac3core/src/codegen/values/ndarray/nditer.rs | 17 +-- 6 files changed, 117 insertions(+), 108 deletions(-) diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index a05e0ce3..56d9094d 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -87,7 +87,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_void = ctx.ctx.void_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -129,8 +129,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ) } @@ -227,7 +227,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); @@ -326,11 +326,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) + TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into()) } /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] @@ -345,7 +341,7 @@ pub fn call_ndarray_calc_broadcast_index< ctx: &mut CodeGenContext<'ctx, '_>, array: NDArrayValue<'ctx>, broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); @@ -385,7 +381,7 @@ pub fn call_ndarray_calc_broadcast_index< TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ) } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 513f8f51..30a33f08 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -356,7 +356,7 @@ where ValueFn: Fn( &mut G, &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, + &TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, ) -> Result, String>, { ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index e6ebb258..9f3ec0e4 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -51,8 +51,8 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { /// This function should be called with a valid index. unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx>; @@ -76,8 +76,8 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn get_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> BasicValueEnum<'ctx> { @@ -107,8 +107,8 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn set_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: BasicValueEnum<'ctx>, ) { @@ -130,32 +130,33 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: } /// An array-like value that can have its array elements accessed as an arbitrary type `T`. -pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeAccessor<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeAccessor<'ctx, Index> { /// Casts an element from [`BasicValueEnum`] into `T`. fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T; /// # Safety /// /// This function should be called with a valid index. - unsafe fn get_typed_unchecked( + unsafe fn get_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> T { let value = unsafe { self.get_unchecked(ctx, generator, idx, name) }; - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } /// Returns the data at the `idx`-th index. - fn get_typed( + fn get_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, @@ -163,62 +164,62 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: name: Option<&str>, ) -> T { let value = self.get(ctx, generator, idx, name); - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } } /// An array-like value that can have its array elements mutated as an arbitrary type `T`. -pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeMutator<'ctx, Index> { /// Casts an element from T into [`BasicValueEnum`]. fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx>; /// # Safety /// /// This function should be called with a valid index. - unsafe fn set_typed_unchecked( + unsafe fn set_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); unsafe { self.set_unchecked(ctx, generator, idx, value) } } /// Sets the data at the `idx`-th index. - fn set_typed( + fn set_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); self.set(ctx, generator, idx, value); } } -/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. -type ValueDowncastFn<'ctx, T> = - Box, BasicValueEnum<'ctx>) -> T + 'ctx>; -/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. -type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; - /// An adapter for constraining untyped array values as typed values. -pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> { +pub struct TypedArrayLikeAdapter< + 'ctx, + G: CodeGenerator + ?Sized, + T, + Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>, +> { adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, } -impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { @@ -229,61 +230,62 @@ where /// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`]. pub fn from( adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, ) -> Self { TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn } } } -impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> ArrayLikeValue<'ctx> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { - fn element_type( + fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> AnyTypeEnum<'ctx> { self.adapted.element_type(ctx, generator) } - fn base_ptr( + fn base_ptr( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> PointerValue<'ctx> { self.adapted.base_ptr(ctx, generator) } - fn size( + fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } } -impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeIndexer<'ctx, Index>, { - unsafe fn ptr_offset_unchecked( + unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) } } - fn ptr_offset( + fn ptr_offset( &self, ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + generator: &mut CG, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -291,44 +293,46 @@ where } } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T { - (self.downcast_fn)(ctx, value) + (self.downcast_fn)(ctx, generator, value) } } -impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeMutator<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx> { - (self.upcast_fn)(ctx, value) + (self.upcast_fn)(ctx, generator, value) } } @@ -384,8 +388,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> { impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 7b1975f0..549bfe3f 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -199,8 +199,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index eef74407..0da3a2ee 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -478,8 +478,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -517,20 +517,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_ impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -570,8 +576,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -609,20 +615,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -667,8 +679,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -748,17 +760,19 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { let llvm_usize = generator.get_size_type(ctx.ctx); - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); + let indices_elem_ty = unsafe { + indices + .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .get_type() + .get_element_type() + }; let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { panic!("Expected list[int32] but got {indices_elem_ty}") }; diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 4b31f274..e29770e0 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -4,7 +4,7 @@ use inkwell::{ AddressSpace, }; -use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator}; +use super::{NDArrayValue, ProxyValue}; use crate::codegen::{ irrt, stmt::{gen_for_callback, BreakContinueHooks}, @@ -106,18 +106,13 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the indices of the current element. #[must_use] - pub fn get_indices( - &'ctx self, - ) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>> - { + pub fn get_indices( + &self, + ) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { TypedArrayLikeAdapter::from( self.indices, - Box::new(|ctx, val| { - ctx.builder - .build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "") - .unwrap() - }), - Box::new(|_, val| val.into()), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), ) } } From d5e8df070adc1c91595881aacc382b79ab887310 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 16:10:00 +0800 Subject: [PATCH 09/46] [core] Minor improvements to IRRT and add missing documentation --- nac3core/src/codegen/generator.rs | 1 + nac3core/src/codegen/irrt/list.rs | 52 ++++++---- nac3core/src/codegen/irrt/math.rs | 20 +++- nac3core/src/codegen/irrt/mod.rs | 11 ++- nac3core/src/codegen/irrt/ndarray/basic.rs | 88 ++++++++++++++--- nac3core/src/codegen/irrt/ndarray/indexing.rs | 5 + nac3core/src/codegen/irrt/ndarray/iter.rs | 20 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 99 ++++++++----------- nac3core/src/codegen/irrt/range.rs | 18 +++- nac3core/src/codegen/types/ndarray/nditer.rs | 19 ++-- nac3core/src/codegen/values/array.rs | 8 ++ nac3core/src/codegen/values/ndarray/mod.rs | 4 + nac3core/src/codegen/values/ndarray/nditer.rs | 4 + 13 files changed, 238 insertions(+), 111 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index f277ec9a..be007c2a 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -17,6 +17,7 @@ pub trait CodeGenerator { /// Return the module name for the code generator. fn get_name(&self) -> &str; + /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index a7fec59d..2c57f8e7 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,42 +24,52 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(dest_idx.0.get_type(), llvm_i32); + assert_eq!(dest_idx.1.get_type(), llvm_i32); + assert_eq!(dest_idx.2.get_type(), llvm_i32); + assert_eq!(src_idx.0.get_type(), llvm_i32); + assert_eq!(src_idx.1.get_type(), llvm_i32); + assert_eq!(src_idx.2.get_type(), llvm_i32); + + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); let slice_assign_fun = { let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step + llvm_i32.into(), // dest start idx + llvm_i32.into(), // dest end idx + llvm_i32.into(), // dest step elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step + llvm_i32.into(), // dest arr len + llvm_i32.into(), // src start idx + llvm_i32.into(), // src end idx + llvm_i32.into(), // src step elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size + llvm_i32.into(), // src arr len + llvm_i32.into(), // size ]; ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); + let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); ctx.module.add_function(fun_symbol, fn_t, None) }) }; - let zero = int32.const_zero(); - let one = int32.const_int(1, false); + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let dest_len = + ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + let src_len = + ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and @@ -136,7 +146,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( BasicTypeEnum::StructType(t) => t.size_of().unwrap(), _ => codegen_unreachable!(ctx), }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap() } .into(), ]; @@ -147,6 +157,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( .map(Either::unwrap_left) .unwrap() }; + // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); @@ -155,7 +166,8 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 4bc95913..33445b2a 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -62,8 +62,13 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isinf", fn_type, None) }); @@ -84,8 +89,13 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isnan", fn_type, None) }); @@ -104,6 +114,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gamma", fn_type, None) @@ -121,6 +133,8 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gammaln", fn_type, None) @@ -138,6 +152,8 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_j0", fn_type, None) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 21a16bdb..4cacdccb 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -132,10 +132,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( generator: &mut G, length: IntValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap(); + let llvm_i32 = ctx.ctx.i32_type(); + + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); + let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap(); Ok(Some(match (start, end, step) { (s, e, None) => ( if let Some(s) = s.as_ref() { @@ -144,7 +145,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( None => return Ok(None), } } else { - int32.const_zero() + llvm_i32.const_zero() }, { let e = if let Some(s) = e.as_ref() { diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 0daea1c4..d11c9b8d 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; @@ -7,19 +8,26 @@ use crate::codegen::{ expr::{create_and_call_function, infer_and_call_function}, irrt::get_usize_dependent_function_name, types::ProxyType, - values::{ndarray::NDArrayValue, ProxyValue}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`. +/// +/// Assets that `shape` does not contain negative dimensions. pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - shape: PointerValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -30,23 +38,37 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ctx, &name, Some(llvm_usize.into()), - &[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())], + &[ + (llvm_usize.into(), shape.size(ctx, generator).into()), + (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), + ], None, None, ); } +/// Generates a call to `__nac3_ndarray_util_assert_shape_output_shape_same`. +/// +/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to +/// an `ndarray`. pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndarray_ndims: IntValue<'ctx>, - ndarray_shape: PointerValue<'ctx>, - output_ndims: IntValue<'ctx>, - output_shape: IntValue<'ctx>, + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -58,16 +80,20 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_ndims.into()), - (llvm_pusize.into(), ndarray_shape.into()), - (llvm_usize.into(), output_ndims.into()), - (llvm_pusize.into(), output_shape.into()), + (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), + (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), + (llvm_usize.into(), output_shape.size(ctx, generator).into()), + (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), ], None, None, ); } +/// Generates a call to `__nac3_ndarray_size`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an +/// `ndarray`, corresponding to the value of `ndarray.size`. pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -90,6 +116,10 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_nbytes`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the +/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -112,6 +142,10 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_len`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of +/// the `ndarray`, corresponding to the value of `ndarray.__len__`. pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -134,6 +168,9 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_is_c_contiguous`. +/// +/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -156,6 +193,9 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_nth_pelement`. +/// +/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -167,6 +207,8 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!(index.get_type(), llvm_usize); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( @@ -181,11 +223,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`. +/// +/// `indices` must have the same number of elements as the number of dimensions in `ndarray`. +/// +/// Returns a [`PointerValue`] to the element indexed by `indices`. pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: PointerValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); @@ -193,6 +240,11 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); @@ -202,7 +254,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized Some(llvm_pi8.into()), &[ (llvm_ndarray.into(), ndarray.as_base_value().into()), - (llvm_pusize.into(), indices.into()), + (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), None, @@ -211,6 +263,9 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized .unwrap() } +/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. +/// +/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -231,6 +286,11 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_ndarray_copy_data`. +/// +/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number +/// of elements in `src_ndarray` must be greater than or equal to the number of elements in +/// `dst_ndarray`. pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 0821b2cd..3e2c908d 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -5,6 +5,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_index`. +/// +/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) +/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the +/// operation `dst_ndarray = src_ndarray[indices]`. pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 966d6605..47cd5b29 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue}, AddressSpace, }; @@ -9,21 +10,29 @@ use crate::codegen::{ types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArrayLikeValue, ArraySliceValue, ProxyValue, + ProxyValue, TypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_nditer_initialize`. +/// +/// Initializes the `iter` object. pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ndarray: NDArrayValue<'ctx>, - indices: ArraySliceValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); create_and_call_function( @@ -40,6 +49,10 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_nditer_initialize_has_element`. +/// +/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` +/// object. pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -59,6 +72,9 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_nditer_next`. +/// +/// Moves `iter` to point to the next element. pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56d9094d..b74ace0f 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,10 +1,11 @@ use inkwell::{ - types::IntType, + types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; +use super::get_usize_dependent_function_name; use crate::codegen::{ llvm_intrinsics, macros::codegen_unreachable, @@ -23,8 +24,8 @@ mod basic; mod indexing; mod iter; -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. +/// Generates a call to `__nac3_ndarray_calc_size`. Returns a +/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. /// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, @@ -43,18 +44,22 @@ where let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); + assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); + assert_eq!( + BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let ndarray_calc_size_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], false, ); let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { + ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); @@ -76,10 +81,10 @@ where .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] /// containing `i32` indices of the flattened index. /// -/// * `index` - The index to compute the multidimensional index for. +/// * `index` - The `llvm_usize` index to compute the multidimensional index for. /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an /// `NDArray`. pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( @@ -94,19 +99,18 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert_eq!(index.get_type(), llvm_usize); + + let ndarray_calc_nd_indices_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { let fn_type = llvm_void.fn_type( &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], false, ); - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -134,15 +138,21 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for +/// the multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx, G, Index>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: &Indices, + indices: &Index, ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, + Index: ArrayLikeIndexer<'ctx>, { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -163,19 +173,16 @@ where "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" ); - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_flatten_index_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], false, ); - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -201,27 +208,8 @@ where index } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] +/// containing the size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -231,13 +219,10 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_calc_broadcast_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[ llvm_pusize.into(), @@ -249,7 +234,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( false, ); - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) }); let lhs_ndims = lhs.load_ndims(ctx); diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 47c63c4f..3b6bc31d 100644 --- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -6,6 +6,13 @@ use itertools::Either; use crate::codegen::{CodeGenContext, CodeGenerator}; +/// Invokes the `__nac3_range_slice_len` in IRRT. +/// +/// - `start`: The `i32` start value for the slice. +/// - `end`: The `i32` end value for the slice. +/// - `step`: The `i32` step value for the slice. +/// +/// Returns an `i32` value of the length of the slice. pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -14,9 +21,15 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( step: IntValue<'ctx>, ) -> IntValue<'ctx> { const SYMBOL: &str = "__nac3_range_slice_len"; + + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(start.get_type(), llvm_i32); + assert_eq!(end.get_type(), llvm_i32); + assert_eq!(step.get_type(), llvm_i32); + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); @@ -33,6 +46,7 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); + ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .map(CallSiteValue::try_as_basic_value) diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c9b6b7d5..772d5b23 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -14,7 +14,7 @@ use crate::codegen::{ types::structure::{check_struct_type_matches_fields, StructField, StructFields}, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArraySliceValue, ProxyValue, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -128,6 +128,11 @@ impl<'ctx> NDIterType<'ctx> { } /// Allocate an [`NDIter`] that iterates through the given `ndarray`. + /// + /// Note: This function allocates an array on the stack at the current builder location, which + /// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to + /// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the + /// [`NDIter`] is no longer needed. #[must_use] pub fn construct( &self, @@ -141,16 +146,12 @@ impl<'ctx> NDIterType<'ctx> { // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap(); + let indices = + TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = >::Value::from_pointer_value( - nditer, - ndarray, - indices, - self.llvm_usize, - None, - ); + let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); - irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices); + irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); nditer } diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 9f3ec0e4..55e91b21 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -265,6 +265,14 @@ where ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } + + fn as_slice_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, + ) -> ArraySliceValue<'ctx> { + self.adapted.as_slice_value(ctx, generator) + } } impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 0da3a2ee..4c5be432 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -358,6 +358,10 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); } + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and + /// copy the contents over. + /// + /// The new ndarray will own its data and will be C-contiguous. #[must_use] pub fn make_copy( &self, diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index e29770e0..4b4e07a1 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -141,6 +141,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// get properties of the current iteration (e.g., the current element, indices, etc.) + /// + /// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and + /// after invoking this function respectively. See [`NDIterType::construct`] for an explanation + /// on why this is suggested. pub fn foreach<'a, G, F>( &self, generator: &mut G, From 3c0ce3031fb21476d7dedd6109d00d61e81fba61 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 13:44:14 +0800 Subject: [PATCH 10/46] [core] codegen: Update raw_alloca to return PointerValue Better match the expected behavior of alloca. --- nac3core/src/codegen/types/list.rs | 4 ++-- nac3core/src/codegen/types/mod.rs | 11 ++++++++--- nac3core/src/codegen/types/ndarray/contiguous.rs | 2 +- nac3core/src/codegen/types/ndarray/indexing.rs | 2 +- nac3core/src/codegen/types/ndarray/mod.rs | 2 +- nac3core/src/codegen/types/ndarray/nditer.rs | 2 +- nac3core/src/codegen/types/range.rs | 4 ++-- nac3core/src/codegen/types/utils/slice.rs | 4 ++-- 8 files changed, 18 insertions(+), 13 deletions(-) diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 3d041349..8de30867 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; @@ -167,7 +167,7 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 022f897b..03d6d387 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -16,7 +16,11 @@ //! the returned object. This is similar to a `new` expression in C++ but the object is allocated //! on the stack. -use inkwell::{context::Context, types::BasicType, values::IntValue}; +use inkwell::{ + context::Context, + types::BasicType, + values::{IntValue, PointerValue}, +}; use super::{ values::{ArraySliceValue, ProxyValue}, @@ -53,13 +57,14 @@ pub trait ProxyType<'ctx>: Into { llvm_ty: Self::Base, ) -> Result<(), String>; - /// Creates a new value of this type, returning the LLVM instance of this value. + /// Creates a new value of this type by invoking `alloca`, returning a [`PointerValue`] instance + /// representing the allocated value. fn raw_alloca( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base; + ) -> PointerValue<'ctx>; /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the /// resulting array. diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 317539c0..4be55474 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -218,7 +218,7 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 959d4f57..6bbd3e8b 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -176,7 +176,7 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 4127ffa8..b655f352 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -430,7 +430,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 772d5b23..ed98aa4b 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -203,7 +203,7 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index e704455b..dc241b6e 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; @@ -131,7 +131,7 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index aba0efb2..dd7643f7 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; use itertools::Itertools; @@ -215,7 +215,7 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, From dc9efa9e8c34eb398c2f28b828df122a70615591 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:08 +0800 Subject: [PATCH 11/46] [core] codegen/ndarray: Use IRRT for size() and indexing operations Also refactor some usages of call_ndarray_calc_size with ndarray.size(). --- nac3core/irrt/irrt/ndarray.hpp | 31 ------- nac3core/src/codegen/builtin_fns.rs | 14 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 76 +---------------- nac3core/src/codegen/numpy.rs | 98 +++++++++++++++++----- nac3core/src/codegen/values/ndarray/mod.rs | 80 ++++-------------- 5 files changed, 106 insertions(+), 193 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 72ca0b9e..7fc9a63b 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n } } -template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, - SizeT num_dims, - const NDIndexInt* indices, - SizeT num_indices) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - template void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims, @@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } -uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, - uint64_t num_dims, - const NDIndexInt* indices, - uint64_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims, const uint32_t* rhs_dims, diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a41b9f55..b21e721e 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = - irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); + let n_sz = n.size(generator, ctx); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx .builder @@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( llvm_int64.const_int(1, false), (n_sz, false), |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; + let elem = unsafe { + n.data().get_unchecked( + ctx, + generator, + &ctx.builder + .build_int_truncate_or_bit_cast(idx, llvm_usize, "") + .unwrap(), + None, + ) + }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index b74ace0f..56017c94 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{BasicTypeEnum, IntType}, + types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; @@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for -/// the multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); - let ndarray_flatten_index_fn = - ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] -/// containing the size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 30a33f08..9328bb83 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,8 +21,8 @@ use super::{ stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, types::{ndarray::NDArrayType, ListType, ProxyType}, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, + ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, @@ -318,12 +318,7 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); + let ndarray_num_elems = ndarray.size(generator, ctx); gen_for_callback_incrementing( generator, @@ -434,6 +429,66 @@ where rhs_val.get_type() ); + // Returns the element of an ndarray indexed by the given indices, performing int-promotion on + // `indices` where necessary. + // + // Required for compatibility with `NDArrayType::get_unchecked`. + let get_data_by_indices_compat = + |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT + let stackptr = llvm_intrinsics::call_stacksave(ctx, None); + let indices = if llvm_usize == ctx.ctx.i32_type() { + indices + } else { + let indices_usize = TypedArrayLikeAdapter::>::from( + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") + .unwrap(), + indices.size(ctx, generator), + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (indices.size(ctx, generator), false), + |generator, ctx, _, i| { + let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; + let idx = ctx + .builder + .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") + .unwrap(); + unsafe { + indices_usize.set_typed_unchecked(ctx, generator, &i, idx); + } + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + indices_usize + }; + + let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; + + llvm_intrinsics::call_stackrestore(ctx, stackptr); + + elem + }; + // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) @@ -455,7 +510,7 @@ where .map_value(lhs_val.into_pointer_value(), None); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } + get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) }; let rhs_elem = if rhs_scalar { @@ -465,7 +520,7 @@ where .map_value(rhs_val.into_pointer_value(), None); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } + get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) }; value_fn(generator, ctx, (lhs_elem, rhs_elem)) @@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); if cfg!(debug_assertions) { @@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() }; let idx0 = unsafe { let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() }; let idx1 = unsafe { let idx1 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() }; let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; @@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( generator, ctx, None, - llvm_i32.const_zero(), + llvm_usize.const_zero(), (common_dim, false), |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - let ab_idx = generator.gen_array_var_alloc( ctx, - llvm_i32.into(), + llvm_usize.into(), llvm_usize.const_int(2, false), None, )?; @@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); // Dimensions are reversed in the transposed array let out = create_ndarray_dyn_shape( @@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ); // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); + let out_sz = out.size(generator, ctx); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), @@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n1_sz = n1.size(generator, ctx); + let n2_sz = n2.size(generator, ctx); ctx.make_assert( generator, diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 4c5be432..e47876c6 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -5,8 +5,8 @@ use inkwell::{ }; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ irrt, @@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_ndarray_calc_size( - generator, - ctx, - &self.as_slice_value(ctx, generator), - (None, None), - ) + irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) } } @@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - idx.get_type(), - "", - ) - .unwrap(); - let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[idx], - name.unwrap_or_default(), - ) - .unwrap() - }; + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); // Current implementation is transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } @@ -769,52 +747,28 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); - let indices_elem_ty = unsafe { - indices - .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type() - }; - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() + let indices = TypedArrayLikeAdapter::from( + indices.as_slice_value(ctx, generator), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ); - let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices); - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - index.get_type(), - "", - ) - .unwrap(); - let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); + let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices( + generator, ctx, *self.0, &indices, + ); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - }; - // TODO: Current implementation is transparent + // Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. ctx.builder .build_pointer_cast( ptr, BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } From 2f0847d77b428d5c38dcb0fbc7abb3724a202f52 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 16:48:31 +0800 Subject: [PATCH 12/46] [core] codegen/types: Refactor ProxyType - Add alloca_type() function to obtain the type that should be passed into a `build_alloca` call - Provide default implementations for raw_alloca and array_alloca - Add raw_alloca_var and array_alloca_var to distinguish alloca instructions placed at the front of the function vs at the current builder location --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/types/list.rs | 57 +++++++---------- nac3core/src/codegen/types/mod.rs | 58 +++++++++++++++--- .../src/codegen/types/ndarray/contiguous.rs | 61 ++++++++----------- .../src/codegen/types/ndarray/indexing.rs | 56 +++++++---------- nac3core/src/codegen/types/ndarray/mod.rs | 60 ++++++++---------- nac3core/src/codegen/types/ndarray/nditer.rs | 61 +++++++++---------- nac3core/src/codegen/types/range.rs | 51 ++++++---------- nac3core/src/codegen/types/utils/slice.rs | 61 ++++++++----------- .../src/codegen/values/ndarray/contiguous.rs | 2 +- .../src/codegen/values/ndarray/indexing.rs | 2 +- 11 files changed, 224 insertions(+), 247 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c3dfc80e..4b781d26 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1108,7 +1108,7 @@ pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( // List structure; type { ty*, size_t } let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca(generator, ctx, name); + let list = arr_ty.alloca_var(generator, ctx, name); let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); list.store_size(ctx, generator, length); diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 8de30867..6608a808 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,13 +1,12 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::codegen::{ - values::{ArraySliceValue, ListValue, ProxyValue}, + values::{ListValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -113,15 +112,33 @@ impl<'ctx> ListType<'ctx> { } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.llvm_usize, name, ) @@ -162,36 +179,8 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 03d6d387..98bd43ba 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -57,24 +57,66 @@ pub trait ProxyType<'ctx>: Into { llvm_ty: Self::Base, ) -> Result<(), String>; - /// Creates a new value of this type by invoking `alloca`, returning a [`PointerValue`] instance - /// representing the allocated value. - fn raw_alloca( + /// Returns the type that should be used in `alloca` IR statements. + fn alloca_type(&self) -> impl BasicType<'ctx>; + + /// Creates a new value of this type by invoking `alloca` at the current builder location, + /// returning a [`PointerValue`] instance representing the allocated value. + fn raw_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> PointerValue<'ctx> { + ctx.builder + .build_alloca(self.alloca_type().as_basic_type_enum(), name.unwrap_or_default()) + .unwrap() + } + + /// Creates a new value of this type by invoking `alloca` at the beginning of the function, + /// returning a [`PointerValue`] instance representing the allocated value. + fn raw_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> PointerValue<'ctx>; + ) -> PointerValue<'ctx> { + generator.gen_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), name).unwrap() + } - /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the - /// resulting array. - fn array_alloca( + /// Creates a new array value of this type by invoking `alloca` at the current builder location, + /// returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca( + self.alloca_type().as_basic_type_enum(), + size, + name.unwrap_or_default(), + ) + .unwrap(), + size, + name, + ) + } + + /// Creates a new array value of this type by invoking `alloca` at the beginning of the + /// function, returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, size: IntValue<'ctx>, name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx>; + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), size, name) + .unwrap() + } /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 4be55474..e5fb8cdc 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -16,7 +16,7 @@ use crate::{ }, ProxyType, }, - values::{ndarray::ContiguousNDArrayValue, ArraySliceValue, ProxyValue}, + values::{ndarray::ContiguousNDArrayValue, ProxyValue}, CodeGenContext, CodeGenerator, }, toplevel::numpy::unpack_ndarray_var_tys, @@ -157,16 +157,37 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { Self { ty: ptr_ty, item, llvm_usize } } - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.item, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.item, self.llvm_usize, name, @@ -213,36 +234,8 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 6bbd3e8b..644e173c 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -90,15 +90,33 @@ impl<'ctx> NDIndexType<'ctx> { Self { ty: ptr_ty, llvm_usize } } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.llvm_usize, name, ) @@ -114,7 +132,7 @@ impl<'ctx> NDIndexType<'ctx> { ) -> ArraySliceValue<'ctx> { // Allocate the LLVM ndindices. let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false); - let ndindices = self.array_alloca(generator, ctx, num_ndindices, None); + let ndindices = self.array_alloca_var(generator, ctx, num_ndindices, None); // Initialize all of them. for (i, in_ndindex) in in_ndindices.iter().enumerate() { @@ -171,36 +189,8 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index b655f352..3886ce84 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -14,7 +14,7 @@ use super::{ }; use crate::{ codegen::{ - values::{ndarray::NDArrayValue, ArraySliceValue, ProxyValue, TypedArrayLikeMutator}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeMutator}, {CodeGenContext, CodeGenerator}, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, @@ -182,15 +182,35 @@ impl<'ctx> NDArrayType<'ctx> { } /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.dtype, self.ndims, self.llvm_usize, @@ -214,7 +234,7 @@ impl<'ctx> NDArrayType<'ctx> { ndims: IntValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { - let ndarray = self.alloca(generator, ctx, name); + let ndarray = self.alloca_var(generator, ctx, name); let itemsize = ctx .builder @@ -425,36 +445,8 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index ed98aa4b..7ce8ed79 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -109,8 +109,31 @@ impl<'ctx> NDIterType<'ctx> { self.llvm_usize } + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + parent, + indices, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -119,7 +142,7 @@ impl<'ctx> NDIterType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), parent, indices, self.llvm_usize, @@ -140,7 +163,7 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - let nditer = self.raw_alloca(generator, ctx, None); + let nditer = self.raw_alloca_var(generator, ctx, None); let ndims = ndarray.load_ndims(ctx); // The caller has the responsibility to allocate 'indices' for `NDIter`. @@ -198,36 +221,8 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index dc241b6e..bdd4e79c 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,13 +1,12 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::codegen::{ - values::{ArraySliceValue, ProxyValue, RangeValue}, + values::{ProxyValue, RangeValue}, {CodeGenContext, CodeGenerator}, }; @@ -78,15 +77,29 @@ impl<'ctx> RangeType<'ctx> { } /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(self.raw_alloca(ctx, name), name) + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), name, ) } @@ -126,36 +139,8 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { Self::is_representable(llvm_ty) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index dd7643f7..fa5a3474 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + values::IntValue, AddressSpace, }; use itertools::Itertools; @@ -15,7 +15,7 @@ use crate::codegen::{ }, ProxyType, }, - values::{utils::SliceValue, ArraySliceValue, ProxyValue}, + values::{utils::SliceValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -154,16 +154,35 @@ impl<'ctx> SliceType<'ctx> { self.int_ty } - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.int_ty, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.int_ty, self.llvm_usize, name, @@ -210,36 +229,8 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 87e2f1d8..f3b03dd1 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -118,7 +118,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca(generator, ctx, self.name); + .alloca_var(generator, ctx, self.name); // Set ndims and shape. let ndims = self diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 69c00807..3d575028 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -248,7 +248,7 @@ impl<'ctx> RustNDIndex<'ctx> { RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) - .alloca(generator, ctx, None); + .alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( From 1ffe2fcc7ff826f20e9b5c127ee1133cdca54aa2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jan 2025 15:28:28 +0800 Subject: [PATCH 13/46] [core] irrt: Minor reformat --- nac3core/irrt/irrt/math.hpp | 2 ++ nac3core/irrt/irrt/string.hpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nac3core/irrt/irrt/math.hpp b/nac3core/irrt/irrt/math.hpp index 1872f564..9dc1377e 100644 --- a/nac3core/irrt/irrt/math.hpp +++ b/nac3core/irrt/irrt/math.hpp @@ -1,5 +1,7 @@ #pragma once +#include "irrt/int_types.hpp" + namespace { // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index db3ad7fd..229b7509 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -5,7 +5,7 @@ namespace { template bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { - if (len1 != len2){ + if (len1 != len2) { return 0; } return __builtin_memcmp(str1, str2, static_cast(len1)) == 0; From 805a9d23b353d078c18deb8902f4b77854a7167f Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 3 Jan 2025 13:58:46 +0800 Subject: [PATCH 14/46] [core] codegen: Add derive(Copy, Clone) to TypedArrayLikeAdapter --- nac3core/src/codegen/values/array.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 55e91b21..b756f278 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -208,6 +208,7 @@ pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntV } /// An adapter for constraining untyped array values as typed values. +#[derive(Copy, Clone)] pub struct TypedArrayLikeAdapter< 'ctx, G: CodeGenerator + ?Sized, From 822f9d33f86693df45e532bde65165751cf9886e Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 13 Dec 2024 16:05:54 +0800 Subject: [PATCH 15/46] [core] codegen: Refactor ListType to use derive(StructFields) --- nac3core/src/codegen/expr.rs | 68 +++--- nac3core/src/codegen/numpy.rs | 14 +- nac3core/src/codegen/types/list.rs | 262 +++++++++++++++++++----- nac3core/src/codegen/types/structure.rs | 15 ++ nac3core/src/codegen/values/list.rs | 64 ++---- 5 files changed, 285 insertions(+), 138 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4b781d26..8d7f8e35 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1090,33 +1090,6 @@ pub fn destructure_range<'ctx>( (start, end, step) } -/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting -/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. -/// -/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element -/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to -/// generate a sized list with an unknown element type. -pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Option>, - length: IntValue<'ctx>, - name: Option<&'ctx str>, -) -> ListValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ty.unwrap_or(llvm_usize.into()); - - // List structure; type { ty*, size_t } - let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca_var(generator, ctx, name); - - let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); - list.store_size(ctx, generator, length); - list.create_data(ctx, llvm_elem_ty, None); - - list -} - /// Generates LLVM IR for a [list comprehension expression][expr]. pub fn gen_comprehension<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1189,12 +1162,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = allocate_list( + list = ListType::new(generator, ctx.ctx, elem_ty).construct( generator, ctx, - Some(elem_ty), list_alloc_size.into_int_value(), - Some("listcomp.addr"), + Some("listcomp"), ); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); @@ -1241,7 +1213,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + list = ListType::new(generator, ctx.ctx, elem_ty).construct( + generator, + ctx, + length, + Some("listcomp"), + ); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1406,7 +1383,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); + let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) + .construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1493,10 +1471,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = allocate_list( + let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( generator, ctx, - Some(elem_llvm_ty), ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), None, ); @@ -2553,7 +2530,20 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Some(elements[0].get_type()) }; let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); - let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); + let arr_str_ptr = if let Some(ty) = ty { + ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("list"), + ) + } else { + ListType::new_untyped(generator, ctx.ctx).construct_empty( + generator, + ctx, + Some("list"), + ) + }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { let elem_ptr = arr_ptr.ptr_offset( @@ -3031,8 +3021,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = - allocate_list(generator, ctx, Some(ty), length, Some("ret")); + let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("ret"), + ); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9328bb83..9b5af0f1 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, + types::{BasicType, BasicTypeEnum, PointerType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -639,17 +639,17 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); + let list_elem_ty = list_ty.element_type().unwrap(); let ndims = llvm_usize.const_int(1, false); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Getting ndims for list[ndarray] not supported") @@ -670,10 +670,10 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let list_elem_ty = src_lst.get_type().element_type(); + let list_elem_ty = src_lst.get_type().element_type().unwrap(); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { // The stride of elements in this dimension, i.e. the number of elements between arr[i] @@ -733,7 +733,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( )?; } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Not implemented for list[ndarray]") diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 6608a808..337d049c 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,68 +1,113 @@ use inkwell::{ - context::Context, + context::{AsContextRef, Context}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - AddressSpace, + values::{IntValue, PointerValue}, + AddressSpace, IntPredicate, OptimizationLevel, }; +use itertools::Itertools; + +use nac3core_derive::StructFields; use super::ProxyType; -use crate::codegen::{ - values::{ListValue, ProxyValue}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + types::structure::{ + check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + }, + values::{ListValue, ProxyValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, }; /// Proxy type for a `list` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct ListType<'ctx> { ty: PointerType<'ctx>, + item: Option>, llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ListStructFields<'ctx> { + /// Array pointer to content. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub items: StructField<'ctx, PointerValue<'ctx>>, + + /// Number of items in the array. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ListStructFields<'ctx> { + #[must_use] + pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + ListStructFields { + items: StructField::create( + &mut counter, + "items", + item.ptr_type(AddressSpace::default()), + ), + len: StructField::create(&mut counter, "len", llvm_usize), + } + } +} + impl<'ctx> ListType<'ctx> { /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. pub fn is_representable( llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Result<(), String> { - let llvm_list_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); - }; - if llvm_list_ty.count_fields() != 2 { - return Err(format!( - "Expected 2 fields in `list`, got {}", - llvm_list_ty.count_fields() - )); - } + let ctx = llvm_ty.get_context(); - let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); - let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); + let llvm_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); }; - let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); - let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); - }; - if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width() - )); - } + let fields = ListStructFields::new(ctx, llvm_usize); - Ok(()) + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { + ListStructFields::new_typed(item, llvm_usize) + } + + /// See [`ListType::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, _ctx: &impl AsContextRef<'ctx>) -> ListStructFields<'ctx> { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of a `List`. #[must_use] fn llvm_type( ctx: &'ctx Context, - element_type: BasicTypeEnum<'ctx>, + element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - // struct List { data: T*, size: size_t } - let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()]; + let element_type = element_type.unwrap_or(llvm_usize.into()); + + let field_tys = + Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } @@ -75,9 +120,50 @@ impl<'ctx> ListType<'ctx> { element_type: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - ListType::from_type(llvm_list, llvm_usize) + Self { ty: llvm_list, item: Some(element_type), llvm_usize } + } + + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + + Self { ty: llvm_list, item: None, llvm_usize } + } + + /// Creates an [`ListType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `item_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { + None + } else { + Some(ctx.get_llvm_type(generator, elem_type)) + }; + + Self { + ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), + item: llvm_elem_type, + llvm_usize, + } } /// Creates an [`ListType`] from a [`PointerType`]. @@ -85,30 +171,39 @@ impl<'ctx> ListType<'ctx> { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - ListType { ty: ptr_ty, llvm_usize } + let ctx = ptr_ty.get_context(); + + // We are just searching for the index off a field - Slot an arbitrary element type in. + let item_field_idx = + Self::fields(ctx.i8_type().into(), llvm_usize).index_of_field(|f| f.items); + let item = unsafe { + ptr_ty + .get_element_type() + .into_struct_type() + .get_field_type_at_index_unchecked(item_field_idx) + .into_pointer_type() + .get_element_type() + }; + let item = BasicTypeEnum::try_from(item).unwrap_or_else(|()| { + panic!( + "Expected BasicTypeEnum for list element type, got {}", + ptr_ty.get_element_type().print_to_string() + ) + }); + + ListType { ty: ptr_ty, item: Some(item), llvm_usize } } /// Returns the type of the `size` field of this `list` type. #[must_use] pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(1) - .map(BasicTypeEnum::into_int_type) - .unwrap() + self.llvm_usize } /// Returns the element type of this `list` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> Option> { + self.item } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. @@ -144,6 +239,73 @@ impl<'ctx> ListType<'ctx> { ) } + /// Allocates a [`ListValue`] on the stack using `item` of this [`ListType`] instance. + /// + /// The returned list will contain: + /// + /// - `data`: Allocated with `len` number of elements. + /// - `len`: Initialized to the value of `len` passed to this function. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let len = ctx.builder.build_int_z_extend(len, self.llvm_usize, "").unwrap(); + + // Generate a runtime assertion if allocating a non-empty list with unknown element type + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None && self.item.is_none() { + let len_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, len, self.llvm_usize.const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + len_eqz, + "0:AssertionError", + "Cannot allocate a non-empty list with unknown element type", + [None, None, None], + ctx.current_loc, + ); + } + + let plist = self.alloca_var(generator, ctx, name); + plist.store_size(ctx, generator, len); + + let item = self.item.unwrap_or(self.llvm_usize.into()); + plist.create_data(ctx, item, None); + + plist + } + + /// Convenience function for creating a list with zero elements. + /// + /// This function is preferred over [`ListType::construct`] if the length is known to always be + /// 0, as this function avoids injecting an IR assertion for checking if a non-empty untyped + /// list is being allocated. + /// + /// The returned list will contain: + /// + /// - `data`: Initialized to `(T*) 0`. + /// - `len`: Initialized to `0`. + #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + let plist = self.alloca_var(generator, ctx, name); + + plist.store_size(ctx, generator, self.llvm_usize.const_zero()); + plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); + + plist + } + /// Converts an existing value into a [`ListValue`]. #[must_use] pub fn map_value( diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 4d6dcaf7..87781d11 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -5,6 +5,7 @@ use inkwell::{ types::{BasicTypeEnum, IntType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, }; +use itertools::Itertools; use crate::codegen::CodeGenContext; @@ -55,6 +56,20 @@ pub trait StructFields<'ctx>: Eq + Copy { { self.into_vec().into_iter() } + + /// Returns the field index of a field in this structure. + fn index_of_field(&self, name: impl FnOnce(&Self) -> StructField<'ctx, V>) -> u32 + where + V: BasicValue<'ctx> + TryFrom, Error = ()>, + { + let field_name = name(self).name; + self.index_of_field_name(field_name).unwrap() + } + + /// Returns the field index of a field with the given name in this structure. + fn index_of_field_name(&self, field_name: &str) -> Option { + self.iter().find_position(|(name, _)| *name == field_name).map(|(idx, _)| idx as u32) + } } /// A single field of an LLVM structure. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 549bfe3f..bd115a2d 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::ListType, + types::{structure::StructField, ListType}, {CodeGenContext, CodeGenerator}, }; @@ -42,48 +42,26 @@ impl<'ctx> ListValue<'ctx> { ListValue { value: ptr, llvm_usize, name } } + fn items_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).items + } + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Returns the pointer to the field storing the size of this `list`. - fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + self.items_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); + self.items_field(ctx).set(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element /// type `elem_ty` and `size`. /// - /// If `size` is [None], the size stored in the field of this instance is used instead. + /// If `size` is [None], the size stored in the field of this instance is used instead. If + /// `size` is resolved to `0` at runtime, `(T*) 0` will be assigned to `data`. pub fn create_data( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -114,6 +92,10 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } + fn len_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).len + } + /// Stores the `size` of this `list` into this instance. pub fn store_size( &self, @@ -123,22 +105,16 @@ impl<'ctx> ListValue<'ctx> { ) { debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); - let psize = self.ptr_to_size(ctx); - ctx.builder.build_store(psize, size).unwrap(); + self.len_field(ctx).set(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. - pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let psize = self.ptr_to_size(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.size"))) - .unwrap_or_default(); - - ctx.builder - .build_load(psize, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() + pub fn load_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> IntValue<'ctx> { + self.len_field(ctx).get(ctx, self.value, name) } } From 7d02f5833d5eea92545607f39e7368463987a987 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 13:56:08 +0800 Subject: [PATCH 16/46] [core] codegen: Implement Tuple{Type,Value} --- nac3core/src/codegen/builtin_fns.rs | 80 ++++-------- nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/tuple.rs | 184 +++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/tuple.rs | 85 +++++++++++++ 6 files changed, 303 insertions(+), 54 deletions(-) create mode 100644 nac3core/src/codegen/types/tuple.rs create mode 100644 nac3core/src/codegen/values/tuple.rs diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index b21e721e..32b95a75 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -14,7 +14,7 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, TupleType}, values::{ ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, @@ -1868,34 +1868,6 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( }) } -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: &[BasicValueEnum<'ctx>], -) -> PointerValue<'ctx> { - let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, *v).unwrap(); - } - } - out_ptr -} - /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -1973,10 +1945,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( None, ); - let q = q.as_base_value().into(); - let r = r.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[q, r]); - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + let q = q.as_base_value().as_basic_value_enum(); + let r = r.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) + .construct_from_objects(ctx, [q, r], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_svd` linalg function @@ -2031,12 +2004,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( None, ); - let u = u.as_base_value().into(); - let s = s.as_base_value().into(); - let vh = vh.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[u, s, vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + let u = u.as_base_value().as_basic_value_enum(); + let s = s.as_base_value().as_basic_value_enum(); + let vh = vh.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + .construct_from_objects(ctx, [u, s, vh], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_inv` linalg function @@ -2158,10 +2131,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( None, ); - let l = l.as_base_value().into(); - let u = u.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[l, u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + let l = l.as_base_value().as_basic_value_enum(); + let u = u.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) + .construct_from_objects(ctx, [l, u], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -2293,10 +2267,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( None, ); - let t = t.as_base_value().into(); - let z = z.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[t, z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) + let t = t.as_base_value().as_basic_value_enum(); + let z = z.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) + .construct_from_objects(ctx, [t, z], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2337,8 +2312,9 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( None, ); - let h = h.as_base_value().into(); - let q = q.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[h, q]); - Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) + let h = h.as_base_value().as_basic_value_enum(); + let q = q.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) + .construct_from_objects(ctx, [h, q], None); + Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1e0fb268..2ce3c9ab 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -42,7 +42,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -574,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - ctx.struct_type(&fields, false).into() + TupleType::new(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 98bd43ba..0a31d6a5 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -28,11 +28,13 @@ use super::{ }; pub use list::*; pub use range::*; +pub use tuple::*; mod list; pub mod ndarray; mod range; pub mod structure; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs new file mode 100644 index 00000000..ccb63b4a --- /dev/null +++ b/nac3core/src/codegen/types/tuple.rs @@ -0,0 +1,184 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, StructType}, + values::BasicValueEnum, +}; +use itertools::Itertools; + +use super::ProxyType; +use crate::{ + codegen::{ + values::{ProxyValue, TupleValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct TupleType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> TupleType<'ctx> { + /// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not. + pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> { + Ok(()) + } + + /// Creates an LLVM type corresponding to the expected structure of a tuple. + #[must_use] + fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { + ctx.struct_type(tys, false) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + + /// Creates an [`TupleType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { + panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty)); + }; + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize } + } + + /// Creates an [`TupleType`] from a [`StructType`]. + #[must_use] + pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(struct_ty).is_ok()); + + TupleType { ty: struct_ty, llvm_usize } + } + + /// Returns the number of elements present in this [`TupleType`]. + #[must_use] + pub fn num_elements(&self) -> u32 { + self.ty.count_fields() + } + + /// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of + /// range. + #[must_use] + pub fn type_at_index(&self, index: u32) -> Option> { + if index < self.num_elements() { + Some(unsafe { self.type_at_index_unchecked(index) }) + } else { + None + } + } + + /// Returns the type of the tuple element at the given `index`. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid. + #[must_use] + pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> { + self.ty.get_field_type_at_index_unchecked(index) + } + + /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. + #[must_use] + pub fn construct( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name) + } + + /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of + /// objects. + #[must_use] + pub fn construct_from_objects>>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + objects: I, + name: Option<&'ctx str>, + ) -> >::Value { + let values = objects.into_iter().collect_vec(); + + assert_eq!(values.len(), self.num_elements() as usize); + assert!(values + .iter() + .enumerate() + .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); + + let mut value = self.construct(ctx, name); + for (i, val) in values.into_iter().enumerate() { + value.store_element(ctx, i as u32, val); + } + + value + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type Base = StructType<'ctx>; + type Value = TupleValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected struct type, got {llvm_ty:?}")) + } + } + + fn is_representable( + _generator: &G, + _ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: TupleType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 032f0417..c789fe0f 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -5,11 +5,13 @@ use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use range::*; +pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs new file mode 100644 index 00000000..5167e479 --- /dev/null +++ b/nac3core/src/codegen/values/tuple.rs @@ -0,0 +1,85 @@ +use inkwell::{ + types::IntType, + values::{BasicValue, BasicValueEnum, StructValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::TupleType, CodeGenContext}; + +#[derive(Copy, Clone)] +pub struct TupleValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> TupleValue<'ctx> { + /// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an + /// instance. + pub fn is_representable( + value: StructValue<'ctx>, + _llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + TupleType::is_representable(value.get_type()) + } + + /// Creates an [`TupleValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(value, llvm_usize).is_ok()); + + Self { value, llvm_usize, name } + } + + /// Stores a value into the tuple element at the given `index`. + pub fn store_element( + &mut self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + element: impl BasicValue<'ctx>, + ) { + assert_eq!(element.as_basic_value_enum().get_type(), unsafe { + self.get_type().type_at_index_unchecked(index) + }); + + let new_value = ctx + .builder + .build_insert_value(self.value, element, index, self.name.unwrap_or_default()) + .unwrap(); + self.value = new_value.into_struct_value(); + } + + /// Loads a value from the tuple element at the given `index`. + pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + ctx.builder + .build_extract_value( + self.value, + index, + &format!("{}[{{i}}]", self.name.unwrap_or("tuple")), + ) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type Base = StructValue<'ctx>; + type Type = TupleType<'ctx>; + + fn get_type(&self) -> Self::Type { + TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: TupleValue<'ctx>) -> Self { + value.as_base_value() + } +} From 5880f964bbf2b134690dff6a0d04c533cc45f2c9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 15:26:18 +0800 Subject: [PATCH 17/46] [core] codegen/ndarray: Reimplement np_{zeros,ones,full,empty} Based on 792374fa: core/ndstrides: implement np_{zeros,ones,full,empty}. --- nac3core/src/codegen/numpy.rs | 291 +++--------------- nac3core/src/codegen/types/ndarray/factory.rs | 146 +++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- nac3core/src/codegen/values/ndarray/mod.rs | 18 ++ nac3core/src/codegen/values/ndarray/shape.rs | 152 +++++++++ 6 files changed, 374 insertions(+), 241 deletions(-) create mode 100644 nac3core/src/codegen/types/ndarray/factory.rs create mode 100644 nac3core/src/codegen/values/ndarray/shape.rs diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9b5af0f1..9c57919c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -3,7 +3,6 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::Itertools; use nac3parser::ast::{Operator, StrRef}; @@ -19,17 +18,28 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ndarray::NDArrayType, ListType, ProxyType}, + types::{ + ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, + }, + ListType, ProxyType, + }, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, - ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{extract_ndims, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::{ magic_methods::Binop, typedef::{FunSignature, Type, TypeEnum}, @@ -174,132 +184,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. - let ndims = shape_tuple.get_type().count_fields(); - - let shape = (0..ndims) - .map(|dim_i| { - ctx.builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .map(BasicValueEnum::into_int_value) - .map(|v| { - ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap() - }) - .unwrap() - }) - .collect_vec(); - - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - let shape_int = - ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap(); - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -529,107 +413,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - })?; - - Ok(ndarray) -} - /// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( generator: &G, @@ -1752,8 +1535,15 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_empty(generator, context, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1770,8 +1560,15 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_zeros(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.ones`. @@ -1788,8 +1585,15 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_ones(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.full`. @@ -1809,8 +1613,15 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + Ok(ndarray.as_base_value()) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs new file mode 100644 index 00000000..13aae8cd --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -0,0 +1,146 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use super::NDArrayType; +use crate::{ + codegen::{ + irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// Get the zero value in `np.zeros()` of a `dtype`. +pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayType<'ctx> { + /// Create an ndarray like + /// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html). + pub fn construct_numpy_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, name); + + // Validate `shape` + irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape); + + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + ndarray + } + + /// Create an ndarray like + /// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html). + pub fn construct_numpy_full( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + fill_value: BasicValueEnum<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_numpy_empty(generator, ctx, shape, name); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like + /// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html). + pub fn construct_numpy_zeros( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_zero_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } + + /// Create an ndarray like + /// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html). + pub fn construct_numpy_ones( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_one_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 3886ce84..89241618 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -25,6 +25,7 @@ pub use indexing::*; pub use nditer::*; mod contiguous; +pub mod factory; mod indexing; mod nditer; diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 7ce8ed79..9b71693a 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,8 +163,13 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { + assert!( + ndarray.get_type().ndims().is_some(), + "NDIter requires ndims of NDArray to be known." + ); + let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = ndarray.load_ndims(ctx); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index e47876c6..ffde76c9 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -23,6 +23,7 @@ pub use nditer::*; mod contiguous; mod indexing; mod nditer; +pub mod shape; mod view; /// Proxy type for accessing an `NDArray` value in LLVM. @@ -397,6 +398,23 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); } + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |_, ctx, _, nditer| { + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs new file mode 100644 index 00000000..190a1e4f --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -0,0 +1,152 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use crate::{ + codegen::{ + stmt::gen_for_callback_incrementing, + types::{ListType, TupleType}, + values::{ + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to +/// `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), +) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let zero = llvm_usize.const_zero(); + let one = llvm_usize.const_int(1, false); + + // The result `list` to return. + match &*ctx.unifier.get_ty_immutable(input_seq_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_pointer_value(), None); + + let len = input_seq.load_size(ctx, None); + // TODO: Find a way to remove this mid-BB allocation + let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap(); + let result = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(result, len, None), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_callback_incrementing( + generator, + ctx, + None, + zero, + (len, false), + |generator, ctx, _, i| { + // Load the i-th int32 in the input sequence + let int = unsafe { + input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value() + }; + + // Cast to SizeT + let int = + ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + // Store + unsafe { result.set_typed_unchecked(ctx, generator, &i, int) }; + + Ok(()) + }, + one, + ) + .unwrap(); + + result + } + + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_struct_value(), None); + + let len = input_seq.get_type().num_elements(); + + let result = generator + .gen_array_var_alloc( + ctx, + llvm_usize.into(), + llvm_usize.const_int(u64::from(len), false), + None, + ) + .unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + for i in 0..input_seq.get_type().num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_seq.load_element(ctx, i).into_int_value(); + let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + unsafe { + result.set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(u64::from(i), false), + int, + ); + } + } + + result + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + let input_int = input_seq.into_int_value(); + + let len = one; + let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let int = + ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap(); + + // Storing into result[0] + unsafe { + result.set_typed_unchecked(ctx, generator, &zero, int); + } + + result + } + + _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)), + } +} From 26f1428739568968aa335e37bab33909ee810931 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 14:21:13 +0800 Subject: [PATCH 18/46] [core] codegen: Refactor len() Based on 54a842a9: core/ndstrides: implement len(ndarray) & refactor len() --- nac3core/src/codegen/builtin_fns.rs | 61 ++++++++++++----------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 32b95a75..54650ab3 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -14,9 +14,9 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::{ndarray::NDArrayType, TupleType}, + types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, + ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -55,42 +55,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TTuple { .. } => { + let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_struct_value(), None); + llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) .map_value(arg.into_pointer_value(), None); - - let ndims = arg.shape().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + ctx.builder + .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") + .unwrap() } - _ => codegen_unreachable!(ctx), + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_pointer_value(), None); + ctx.builder + .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") + .unwrap() + } + + _ => unsupported_type(ctx, "len", &[arg_ty]), } }) } From fadadd7505c99b3accbd942200aa657d64ad28b1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 20 Aug 2024 14:51:40 +0800 Subject: [PATCH 19/46] [core] codegen/ndarray: Reimplement np_array() Based on 8f0084ac: core/ndstrides: implement np_array() It also checks for inconsistent dimensions if the input is a list. e.g., rejecting `[[1.0, 2.0], [3.0]]`. However, currently only `np_array(, copy=False)` and `np_array (, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(, copy=False)` behaves like NumPy's `np.array(, copy=None)`. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/list.hpp | 15 + nac3core/irrt/irrt/ndarray/array.hpp | 132 ++++++ nac3core/src/codegen/irrt/ndarray/array.rs | 80 ++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/numpy.rs | 460 +------------------- nac3core/src/codegen/types/ndarray/array.rs | 245 +++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/values/list.rs | 19 +- nac3core/src/codegen/values/ndarray/mod.rs | 2 +- 10 files changed, 512 insertions(+), 445 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/array.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/array.rs create mode 100644 nac3core/src/codegen/types/ndarray/array.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 0d069869..57f60d52 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,3 +9,4 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/array.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index 28543945..1edfe498 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -2,6 +2,21 @@ #include "irrt/int_types.hpp" #include "irrt/math_util.hpp" +#include "irrt/slice.hpp" + +namespace { +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. + */ +template +struct List { + uint8_t* items; + SizeT len; +}; +} // namespace extern "C" { // Handle list assignment and dropping part of the list when diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 00000000..126669e7 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/list.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::array { +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], + * [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the + * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because + * of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List* list, SizeT ndims, SizeT* shape) { + if (shape[axis] == -1) { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } else { + // Dimension is specified. Check. + if (shape[axis] != list->len) { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. + raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) { + // `list` has type `list[ItemType]` + // Do nothing + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + for (SizeT i = 0; i < list->len; i++) { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template +void set_and_validate_list_shape(List* list, SizeT ndims, SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. + */ +template +void write_list_to_array_helper(SizeT axis, SizeT* index, List* list, NDArray* ndarray) { + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) { + if (!ndarray::basic::is_c_contiguous(ndarray)) { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t* dst = static_cast(ndarray->data) + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + + for (SizeT i = 0; i < list->len; i++) { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template +void write_list_to_array(List* list, NDArray* ndarray) { + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace ndarray::array +} // namespace + +extern "C" { +using namespace ndarray::array; + +void __nac3_ndarray_array_set_and_validate_list_shape(List* list, int32_t ndims, int32_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_set_and_validate_list_shape64(List* list, int64_t ndims, int64_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_write_list_to_array(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} + +void __nac3_ndarray_array_write_list_to_array64(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs new file mode 100644 index 00000000..931b66cb --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -0,0 +1,80 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_array_set_and_validate_list_shape`. +/// +/// Deduces the target shape of the `ndarray` from the provided `list`, raising an exception if +/// there is any issue with the resultant `shape`. +/// +/// `shape` must be pre-allocated by the caller of this function to `[usize; ndims]`, and must be +/// initialized to all `-1`s. +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndims: IntValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + assert_eq!(ndims.get_type(), llvm_usize); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_set_and_validate_list_shape", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_array_write_list_to_array`. +/// +/// Copies the contents stored in `list` into `ndarray`. +/// +/// The `ndarray` must fulfill the following preconditions: +/// +/// - `ndarray.itemsize`: Must be initialized. +/// - `ndarray.ndims`: Must be initialized. +/// - `ndarray.shape`: Must be initialized. +/// - `ndarray.data`: Must be allocated and contiguous. +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) { + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_write_list_to_array", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndarray.as_base_value().into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56017c94..307ec6bb 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -16,10 +16,12 @@ use crate::codegen::{ }, CodeGenContext, CodeGenerator, }; +pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +mod array; mod basic; mod indexing; mod iter; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9c57919c..09d848dc 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,7 +1,7 @@ use inkwell::{ - types::{BasicType, BasicTypeEnum, PointerType}, + types::BasicType, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + IntPredicate, OptimizationLevel, }; use nac3parser::ast::{Operator, StrRef}; @@ -18,12 +18,9 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ - ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, - ListType, ProxyType, + types::ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, }, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, @@ -35,14 +32,10 @@ use super::{ }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{ - helper::{extract_ndims, PrimDef}, - numpy::unpack_ndarray_var_tys, - DefinitionId, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{FunSignature, Type}, }, }; @@ -413,394 +406,6 @@ where Ok(res) } -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type().unwrap(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type().unwrap(); - - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - let offset = ctx - .builder - .build_int_mul( - offset, - ctx.builder - .build_int_truncate_or_bit_cast( - dst_arr.get_type().element_type().size_of().unwrap(), - offset.get_type(), - "", - ) - .unwrap(), - "", - ) - .unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_pointer_value( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - let sizeof_elem = - ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. -fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_pointer_value( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_elem_ty, - None, - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_representable(object, llvm_usize).is_ok()); - let object = ListValue::from_pointer_value(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = - ListValue::from_pointer_value(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(next_dim, llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_pointer_value( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - /// LLVM-typed implementation for generating the implementation for `ndarray.eye`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1635,26 +1240,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1670,28 +1255,17 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let copy = generator.bool_to_i1(context, copy_arg.into_int_value()); + let ndarray = NDArrayType::from_unifier_type(generator, context, fun.0.ret) + .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.eye`. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs new file mode 100644 index 00000000..87cd002a --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -0,0 +1,245 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, IntValue}, + AddressSpace, +}; + +use crate::{ + codegen::{ + irrt, + stmt::gen_if_else_expr_callback, + types::{ndarray::NDArrayType, ListType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array()`. +fn get_list_object_dtype_and_ndims<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + list_ty: Type, +) -> (BasicTypeEnum<'ctx>, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list_ty); + let ndims = arraylike_get_ndims(&mut ctx.unifier, list_ty); + + (ctx.get_llvm_type(generator, dtype), ndims) +} + +impl<'ctx> NDArrayType<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn construct_numpy_array_from_list_copy_true_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert_eq!(dtype, self.dtype); + + let list_value = list.as_i8_list(generator, ctx); + + // Validate `list` has a consistent shape. + // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = self.llvm_usize.const_int(ndims_int, false); + let shape = ctx.builder.build_array_alloca(self.llvm_usize, ndims, "").unwrap(); + let shape = ArraySliceValue::from_ptr_val(shape, ndims, None); + let shape = TypedArrayLikeAdapter::from( + shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + irrt::ndarray::call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, &shape, + ); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, name); + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + // Copy all contents from the list. + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( + generator, ctx, list_value, ndarray, + ); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn construct_numpy_array_from_list_copy_none_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + if ndims == 1 { + // `list` is not nested + assert_eq!(ndims, 1); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert_eq!(dtype, self.dtype); + + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + .construct_uninitialized(generator, ctx, name); + + // Set data + let data = ctx + .builder + .build_pointer_cast(list.data().base_ptr(ctx, generator), llvm_pi8, "") + .unwrap(); + ndarray.store_data(ctx, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.shape(); + let list_len = list.load_size(ctx, None); + unsafe { + shape.set_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), list_len); + } + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(generator, ctx); + + ndarray + } else { + // `list` is nested, copy + self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn construct_numpy_array_list_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_none_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn construct_numpy_array_ndarray_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(ndarray.get_type().dtype, self.dtype); + assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self + .ndims + .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.as_base_value())) + }, + |_generator, _ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + ndarray.get_type().map_value(ndarray_val, name) + } + + /// Create a new ndarray like + /// [`np.array()`](https://numpy.org/doc/stable/reference/generated/numpy.array.html). + /// + /// Note that the returned [`NDArrayValue`] may have fewer dimensions than is specified by this + /// instance. Use [`NDArrayValue::atleast_nd`] on the returned value if an `ndarray` instance + /// with the exact number of dimensions is needed. + pub fn construct_numpy_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + match &*ctx.unifier.get_ty_immutable(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) + } + + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 89241618..17bb6adc 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -24,6 +24,7 @@ pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod array; mod contiguous; pub mod factory; mod indexing; diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index bd115a2d..c497f8f8 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::{structure::StructField, ListType}, + types::{structure::StructField, ListType, ProxyType}, {CodeGenContext, CodeGenerator}, }; @@ -116,6 +116,23 @@ impl<'ctx> ListValue<'ctx> { ) -> IntValue<'ctx> { self.len_field(ctx).get(ctx, self.value, name) } + + /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. + #[must_use] + pub fn as_i8_list( + &self, + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> ListValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + + Self::from_pointer_value( + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + self.llvm_usize, + self.name, + ) + } } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index ffde76c9..d4a460a5 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -173,7 +173,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { let data = ctx .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") From acb437919d951e0c14a468a55d2908282cc16044 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 18:01:12 +0800 Subject: [PATCH 20/46] [core] codegen/ndarray: Reimplement np_{eye,identity} Based on fa047d50: core/ndstrides: implement np_identity() and np_eye() --- nac3core/src/codegen/numpy.rs | 107 ++++++------------ nac3core/src/codegen/types/ndarray/factory.rs | 94 ++++++++++++++- 2 files changed, 126 insertions(+), 75 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 09d848dc..703d03ec 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,10 +18,7 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, + types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, @@ -406,55 +403,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let nrows = context + .builder + .build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ncols = context + .builder + .build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let offset = context + .builder + .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") + .unwrap(); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.identity`. @@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let n = context + .builder + .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_identity(generator, context, dtype, n, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.copy`. diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 13aae8cd..300167f7 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -1,4 +1,7 @@ -use inkwell::values::{BasicValueEnum, IntValue}; +use inkwell::{ + values::{BasicValueEnum, IntValue}, + IntPredicate, +}; use super::NDArrayType; use crate::{ @@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( } /// Get the one value in `np.ones()` of a `dtype`. -pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, @@ -143,4 +146,91 @@ impl<'ctx> NDArrayType<'ctx> { let fill_value = ndarray_one_value(generator, ctx, dtype); self.construct_numpy_full(generator, ctx, shape, fill_value, name) } + + /// Create an ndarray like + /// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html). + #[allow(clippy::too_many_arguments)] + pub fn construct_numpy_eye( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + assert_eq!(nrows.get_type(), self.llvm_usize); + assert_eq!(ncols.get_type(), self.llvm_usize); + assert_eq!(offset.get_type(), self.llvm_usize); + + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name); + + // Create data and make the matrix like look np.eye() + unsafe { + ndarray.create_data(generator, ctx); + } + ndarray + .foreach(generator, ctx, |generator, ctx, _, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + let indices = nditer.get_indices(); + + let row_i = unsafe { + indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None) + }; + let col_i = unsafe { + indices.get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(1, false), + None, + ) + }; + + let be_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + ctx.builder.build_int_add(row_i, offset, "").unwrap(), + col_i, + "", + ) + .unwrap(); + let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like + /// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html). + pub fn construct_numpy_identity( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let offset = self.llvm_usize.const_zero(); + self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name) + } } From 9ffa2d6552e5052aba16bdba9b9f081efa138c68 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 09:53:00 +0800 Subject: [PATCH 21/46] [core] codegen/ndarray: Reimplement np_{copy,fill} Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy() --- nac3core/src/codegen/numpy.rs | 57 ++++------------------ nac3core/src/codegen/values/ndarray/mod.rs | 2 + 2 files changed, 11 insertions(+), 48 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 703d03ec..2f899f95 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>( assert!(args.is_empty()); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_copy_impl( - generator, - context, - this_elem_ty, - llvm_this_ty.map_value(this_arg.into_pointer_value(), None), - ) - .map(NDArrayValue::into) + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.fill`. @@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>( assert_eq!(args.len(), 1); let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_fill_flattened( - generator, - context, - llvm_this_ty.map_value(this_arg, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + this.fill(generator, context, value_arg); Ok(()) } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d4a460a5..2907445e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>, ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. self.foreach(generator, ctx, |_, ctx, _, nditer| { let p = nditer.get_pointer(ctx); ctx.builder.build_store(p, value).unwrap(); From 12358c57b1a17592e18fa6f2acd161b87b22269b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:28:56 +0800 Subject: [PATCH 22/46] [core] codegen/ndarray: Implement np_{shape,strides} Based on 40c24486: core/ndstrides: implement np_shape() and np_strides() These functions are not important, but they are handy for debugging. `np.strides()` is not an actual NumPy function, but `ndarray.strides` is used. --- nac3core/src/codegen/values/ndarray/mod.rs | 86 +++++++++++++++++-- nac3core/src/toplevel/builtins.rs | 53 ++++++++++++ nac3core/src/toplevel/helper.rs | 8 ++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 41 ++++++++- nac3standalone/demo/interpret_demo.py | 4 + 10 files changed, 193 insertions(+), 13 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 2907445e..951792fb 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -1,19 +1,23 @@ +use std::iter::repeat_n; + use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }; use crate::codegen::{ irrt, llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, stmt::gen_for_callback_incrementing, type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField}, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; pub use contiguous::*; @@ -417,13 +421,85 @@ impl<'ctx> NDArrayValue<'ctx> { .unwrap(); } + /// Create the shape tuple of this ndarray like + /// [`np.shape()`](https://numpy.org/doc/stable/reference/generated/numpy.shape.html). + /// + /// All elements in the tuple are `i32`. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.shape().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + + /// Create the strides tuple of this ndarray like + /// [`.strides`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html). + /// + /// All elements in the tuple are `i32`. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.strides().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { self.ndims.map(|ndims| ndims == 0) } - /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. /// Otherwise, do nothing and return the ndarray itself. // TODO: Rename to get_unsized_element pub fn split_unsized( diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 36bd85d1..ac3fa08f 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,6 +14,7 @@ use crate::{ builtin_fns, numpy::*, stmt::exn_constructor, + types::ndarray::NDArrayType, values::{ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, @@ -368,6 +369,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1242,6 +1247,54 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.as_base_value().into())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 71c1859b..75a7eabc 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,10 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpShape, + FunNpStrides, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -240,6 +244,10 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 93f2096b..9313448e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index d3301d00..0aa21de1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 911426b9..2490cc75 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index d60daf83..a7230f4d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 517f6846..871a2f89 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 6068f630..8f1c54fc 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::{ cmp::max, collections::{HashMap, HashSet}, convert::{From, TryInto}, - iter::once, + iter::{once, repeat_n}, sync::Arc, }; @@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..5bcf4bb5 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray property getters + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf From 132ba1942f018981d201e34520bea070d6137e5b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:58:38 +0800 Subject: [PATCH 23/46] [core] toplevel: Implement np_size Based on 2c1030d1: core/ndstrides: implement np_size() --- nac3core/src/toplevel/builtins.rs | 35 +++++++++++++++++-- nac3core/src/toplevel/helper.rs | 2 ++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +-- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +-- nac3standalone/demo/interpret_demo.py | 1 + 8 files changed, 43 insertions(+), 9 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ac3fa08f..de5a278e 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -369,7 +369,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpShape | PrimDef::FunNpStrides => { + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { self.build_ndarray_property_getter_function(prim) } @@ -1248,7 +1248,10 @@ impl<'a> BuiltinBuilder<'a> { } fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1257,6 +1260,34 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let size = ctx + .builder + .build_int_truncate_or_bit_cast( + ndarray.size(generator, ctx), + ctx.ctx.i32_type(), + "", + ) + .unwrap(); + Ok(Some(size.into())) + }), + ), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { // The function signatures of `np_shape` an `np_size` are the same. // Mixed together for convenience. diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 75a7eabc..10e77422 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -55,6 +55,7 @@ pub enum PrimDef { FunNpIdentity, // NumPy ndarray property getters + FunNpSize, FunNpShape, FunNpStrides, @@ -245,6 +246,7 @@ impl PrimDef { PrimDef::FunNpIdentity => fun("np_identity", None), // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 9313448e..4650fbbf 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 0aa21de1..b67596d8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 2490cc75..08f254f5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index a7230f4d..ce3b02ed 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 871a2f89..b053b814 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 5bcf4bb5..ca17c3da 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray property getters + module.np_size = np.size module.np_shape = np.shape module.np_strides = lambda ndarray: ndarray.strides From aae41eef6a5de558b4202f076a4cb1785aaaf32b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 11:05:45 +0800 Subject: [PATCH 24/46] [core] toplevel: Add view functions category Based on 9e0f636d: core: categorize np_{transpose,reshape} as 'view functions' --- nac3core/src/toplevel/builtins.rs | 108 +++++++++--------- nac3core/src/toplevel/helper.rs | 12 +- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3standalone/demo/interpret_demo.py | 6 +- 8 files changed, 72 insertions(+), 68 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index de5a278e..276c00c7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -373,6 +373,10 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } + PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_ndarray_view_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -438,10 +442,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { - self.build_np_sp_ndarray_function(prim) - } - PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -1326,6 +1326,55 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build np/sp functions that take as input `NDArray` only + fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpTranspose => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + }), + ), + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpReshape => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }), + ), + + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1813,57 +1862,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build np/sp functions that take as input `NDArray` only - fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); - - match prim { - PrimDef::FunNpTranspose => { - let ndarray_ty = self.unifier.get_fresh_var_with_range( - &[self.ndarray_num_ty], - Some("T".into()), - None, - ); - create_fn_by_codegen( - self.unifier, - &into_var_map([ndarray_ty]), - prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ) - } - - // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and - // the `param_ty` for `create_fn_by_codegen`. - // - // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking - // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], - // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), - - _ => unreachable!(), - } - } - /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 10e77422..9313b13b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -59,6 +59,10 @@ pub enum PrimDef { FunNpShape, FunNpStrides, + // NumPy ndarray view functions + FunNpTranspose, + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -106,8 +110,6 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunNpTranspose, - FunNpReshape, // Linalg functions FunNpDot, @@ -250,6 +252,10 @@ impl PrimDef { PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), + // NumPy NDArray view functions + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), @@ -297,8 +303,6 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunNpTranspose => fun("np_transpose", None), - PrimDef::FunNpReshape => fun("np_reshape", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 4650fbbf..b03b3616 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b67596d8..b4df49c9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 08f254f5..65a6a8ac 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index ce3b02ed..cfedf1f6 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index b053b814..9a77a21c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", ] diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index ca17c3da..56c6126d 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray view functions + module.np_transpose = np.transpose + module.np_reshape = np.reshape + # NumPy NDArray property getters module.np_size = np.size module.np_shape = np.shape @@ -223,8 +227,6 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter - module.np_transpose = np.transpose - module.np_reshape = np.reshape # SciPy Math functions module.sp_spec_erf = special.erf From 8d975b5ff3ecf04c71ba7971aac8c1f7c91dab81 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 11:40:23 +0800 Subject: [PATCH 25/46] [core] codegen/ndarray: Implement np_reshape Based on 926e7e93: core/ndstrides: implement np_reshape() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/reshape.hpp | 97 +++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/irrt/ndarray/reshape.rs | 40 +++ nac3core/src/codegen/numpy.rs | 334 +----------------- nac3core/src/codegen/values/ndarray/view.rs | 77 +++- nac3core/src/toplevel/builtins.rs | 59 +++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3standalone/demo/src/ndarray.py | 55 ++- 13 files changed, 308 insertions(+), 373 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/reshape.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/reshape.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 57f60d52..bb7fb3d4 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,5 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" -#include "irrt/ndarray/array.hpp" \ No newline at end of file +#include "irrt/ndarray/array.hpp" +#include "irrt/ndarray/reshape.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 00000000..b2ad2a5b --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::reshape { +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template +void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) { + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) { + SizeT dim = new_shape[axis_i]; + if (dim < 0) { + if (dim == -1) { + if (neg1_exists) { + // Multiple `-1` found. Throw an error. + raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, + NO_PARAM, NO_PARAM); + } else { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } else { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... + + raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, + NO_PARAM); + } + } else { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) { + // Let `x` be the unknown dimension + // Solve `x * = ` + if (new_size == 0 && size == 0) { + // `x` has infinitely many solutions + can_reshape = false; + } else if (new_size == 0 && size != 0) { + // `x` has no solutions + can_reshape = false; + } else if (size % new_size != 0) { + // `x` has no integer solutions + can_reshape = false; + } else { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } else { + can_reshape = (new_size == size); + } + + if (!can_reshape) { + raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, + NO_PARAM); + } +} +} // namespace ndarray::reshape +} // namespace + +extern "C" { +void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} + +void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 307ec6bb..f67566b9 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -20,11 +20,13 @@ pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +pub use reshape::*; mod array; mod basic; mod indexing; mod iter; +mod reshape; /// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs new file mode 100644 index 00000000..32de2fa1 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -0,0 +1,40 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ArrayLikeValue, ArraySliceValue}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`. +/// +/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(, new_shape)`, raising an +/// assertion if multiple dimensions are unknown (`-1`). +pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + new_ndims: IntValue<'ctx>, + new_shape: ArraySliceValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(size.get_type(), llvm_usize); + assert_eq!(new_ndims.get_type(), llvm_usize); + assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_reshape_resolve_and_check_new_shape", + ); + infer_and_call_function( + ctx, + &name, + None, + &[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 2f899f95..ce4e2059 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,9 +21,9 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -134,46 +134,6 @@ where Ok(ndarray) } -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - } - - let llvm_dtype = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64)) - .construct_dyn_shape(generator, ctx, shape, None); - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -1455,294 +1415,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( } } -/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. -/// -/// * `x1` - `NDArray` to reshape. -/// * `shape` - The `shape` parameter used to construct the new `NDArray`. -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` -/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// -/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. -pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - shape: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_reshape"; - let (x1_ty, x1) = x1; - let (_, shape) = shape; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = n1.size(generator, ctx); - - let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); - ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); - - let out = match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - // Check for -1 in dimensions - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_list.load_size(ctx, None), false), - |generator, ctx, _, idx| { - let ele = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - ele, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_neg_value = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_neg_value = ctx - .builder - .build_int_add( - num_neg_value, - llvm_usize.const_int(1, false), - "", - ) - .unwrap(); - ctx.builder.build_store(num_neg, num_neg_value).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_value = - ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_value = - ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); - ctx.builder.build_store(acc, acc_value).unwrap(); - Ok(None) - }, - )?; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - // Generate the output shape by filling -1 with `rem` - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, _| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - let dim = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` - - let ndims = shape_tuple.get_type().count_fields(); - // Check for -1 in dims - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_negs = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_negs = ctx - .builder - .build_int_add(num_negs, llvm_usize.const_int(1, false), "") - .unwrap(); - ctx.builder.build_store(num_neg, num_negs).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); - ctx.builder.build_store(acc, acc_val).unwrap(); - Ok(None) - }, - )?; - } - - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - let mut shape = Vec::with_capacity(ndims as usize); - - // Reconstruct shape filling negatives with rem - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - let dim = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value(); - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` - let shape_int = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - shape_int, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(n_sz)), - |_, ctx| { - Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) - }, - )? - .unwrap() - .into_int_value(); - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } - .unwrap(); - - // Only allow one dimension to be negative - let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "can only specify one unknown dimension", - [None, None, None], - ctx.current_loc, - ); - - // The new shape must be compatible with the old shape - let out_sz = out.size(generator, ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), - "0:ValueError", - "cannot reshape array of size {0} into provided shape of size {1}", - [Some(n_sz), Some(out_sz), None], - ctx.current_loc, - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 70a9d659..8334d379 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,9 +1,16 @@ use std::iter::{once, repeat_n}; +use inkwell::values::IntValue; use itertools::Itertools; use crate::codegen::{ - values::ndarray::{NDArrayValue, RustNDIndex}, + irrt, + stmt::gen_if_callback, + types::ndarray::NDArrayType, + values::{ + ndarray::{NDArrayValue, RustNDIndex}, + ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, + }, CodeGenContext, CodeGenerator, }; @@ -33,4 +40,72 @@ impl<'ctx> NDArrayValue<'ctx> { *self } } + + /// Create a reshaped view on this ndarray like + /// [`np.reshape()`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html). + /// + /// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be + /// modified as a result. + /// + /// If reshape without copying is impossible, this function will allocate a new ndarray and copy + /// contents. + /// + /// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`]. + /// * `new_shape` - The target shape to do `np.reshape()`. + #[must_use] + pub fn reshape_or_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + new_ndims: u64, + new_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert_eq!(new_shape.element_type(ctx, generator), self.llvm_usize.into()); + + // TODO: The current criterion for whether to do a full copy or not is by checking + // `is_c_contiguous`, but this is not optimal - there are cases when the ndarray is + // not contiguous but could be reshaped without copying data. Look into how numpy does + // it. + + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + .construct_uninitialized(generator, ctx, None); + dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); + + // Resolve negative indices + let size = self.size(generator, ctx); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_shape = dst_ndarray.shape(); + irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( + generator, + ctx, + size, + dst_ndims, + dst_shape.as_slice_value(ctx, generator), + ); + + gen_if_callback( + generator, + ctx, + |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |generator, ctx| { + // Reshape is possible without copying + dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); + + Ok(()) + }, + |generator, ctx| { + // Reshape is impossible without copying + unsafe { + dst_ndarray.create_data(generator, ctx); + } + dst_ndarray.copy_data_from(generator, ctx, *self); + + Ok(()) + }, + ) + .unwrap(); + + dst_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 276c00c7..db7acaf3 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -5,8 +5,10 @@ use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ - helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails}, - numpy::make_ndarray_ty, + helper::{ + debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + }, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, }; use crate::{ @@ -15,7 +17,7 @@ use crate::{ numpy::*, stmt::exn_constructor, types::ndarray::NDArrayType, - values::{ProxyValue, RangeValue}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -193,7 +195,6 @@ struct BuiltinBuilder<'a> { ndarray_float: Type, ndarray_float_2d: Type, - ndarray_num_ty: Type, float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, @@ -307,7 +308,6 @@ impl<'a> BuiltinBuilder<'a> { ndarray_float, ndarray_float_2d, - ndarray_num_ty, float_or_ndarray_ty, float_or_ndarray_var_map, @@ -1356,20 +1356,41 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), + PrimDef::FunNpReshape => { + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[ + (in_ndarray_ty.ty, "x"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding + ], + Box::new(move |ctx, _, fun, args, generator| { + let ndarray_ty = fun.0.args[0].ty; + let ndarray_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let shape_ty = fun.0.args[1].ty; + let shape_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray_val.into_pointer_value(), None); + + let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); + + // The ndims after reshaping is gotten from the return type of the call. + let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + }), + ) + } _ => unreachable!(), } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index b03b3616..f5349890 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b4df49c9..05408683 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 65a6a8ac..029815aa 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index cfedf1f6..fc8f1aaf 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a77a21c..34e78a2b 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", ] diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index d42f3b93..a82cfbad 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,13 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_4(n: ndarray[float, Literal[4]]): + for x in range(len(n)): + for y in range(len(n[x])): + for z in range(len(n[x][y])): + for w in range(len(n[x][y][z])): + output_float64(n[x][y][z][w]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -197,6 +204,38 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_reshape(): + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x = np_reshape(w, (1, 2, 1, -1)) + y = np_reshape(x, [2, -1]) + z = np_reshape(y, 10) + + output_int32(np_shape(w)[0]) + output_ndarray_float_1(w) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_int32(np_shape(x)[2]) + output_int32(np_shape(x)[3]) + output_ndarray_float_4(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_ndarray_float_1(z) + + x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) + x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + + output_int32(np_shape(x1)[0]) + output_ndarray_int32_1(x1) + + output_int32(np_shape(x2)[0]) + output_int32(np_shape(x2)[1]) + output_ndarray_int32_2(x2) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1448,19 +1487,6 @@ def test_ndarray_transpose(): output_ndarray_float_2(x) output_ndarray_float_2(y) -def test_ndarray_reshape(): - w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - x = np_reshape(w, (1, 2, 1, -1)) - y = np_reshape(x, [2, -1]) - z = np_reshape(y, 10) - - x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) - x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) - - output_ndarray_float_1(w) - output_ndarray_float_2(y) - output_ndarray_float_1(z) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1592,6 +1618,8 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_reshape() + test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar() @@ -1756,7 +1784,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_transpose() - test_ndarray_reshape() test_ndarray_dot() test_ndarray_cholesky() From 43e440d2fda7857f19544d61e2dccee21b5a79eb Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 15:23:41 +0800 Subject: [PATCH 26/46] [core] codegen/ndarray: Reimplement broadcasting Based on 9359ed96: core/ndstrides: implement broadcasting & np_broadcast_to() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/broadcast.hpp | 165 ++++++++++++ .../src/codegen/irrt/ndarray/broadcast.rs | 82 ++++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/types/ndarray/broadcast.rs | 176 +++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 16 ++ .../src/codegen/values/ndarray/broadcast.rs | 248 ++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 2 + nac3core/src/toplevel/builtins.rs | 24 +- nac3core/src/toplevel/helper.rs | 2 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 2 +- nac3standalone/demo/interpret_demo.py | 1 + nac3standalone/demo/src/ndarray.py | 24 ++ 18 files changed, 748 insertions(+), 13 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/broadcast.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/types/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/values/ndarray/broadcast.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index bb7fb3d4..09844d3b 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -10,4 +10,5 @@ #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" -#include "irrt/ndarray/reshape.hpp" \ No newline at end of file +#include "irrt/ndarray/reshape.hpp" +#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..6e54b1ca --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +namespace { +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray::broadcast { +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) { + if (src_ndims > target_ndims) { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry* shapes, SizeT dst_ndims, SizeT* dst_shape) { + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) { + dst_shape[dst_axis] = entry_dim; + } else if (entry_dim == 1 || entry_dim == dst_dim) { + // Do nothing + } else { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + +#ifdef IRRT_DEBUG_ASSERT + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +#endif +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template +void broadcast_to(const NDArray* src_ndarray, NDArray* dst_ndarray) { + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } else { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace ndarray::broadcast +} // namespace + +extern "C" { +using namespace ndarray::broadcast; + +void __nac3_ndarray_broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_to64(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, + const ShapeEntry* shapes, + int32_t dst_ndims, + int32_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} + +void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, + const ShapeEntry* shapes, + int64_t dst_ndims, + int64_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs new file mode 100644 index 00000000..cb1ecd4c --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -0,0 +1,82 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + types::{ndarray::ShapeEntryType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_broadcast_to`. +/// +/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`. +/// +/// `dst_ndarray` must meet the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. +/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + infer_and_call_function( + ctx, + &name, + None, + &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_broadcast_shapes`. +/// +/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`, +/// writing the result to `dst_shape`. +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + num_shape_entries: IntValue<'ctx>, + shape_entries: ArraySliceValue<'ctx>, + dst_ndims: IntValue<'ctx>, + dst_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(num_shape_entries.get_type(), llvm_usize); + assert!(ShapeEntryType::is_type( + generator, + ctx.ctx, + shape_entries.base_ptr(ctx, generator).get_type() + ) + .is_ok()); + assert_eq!(dst_ndims.get_type(), llvm_usize); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + infer_and_call_function( + ctx, + &name, + None, + &[ + num_shape_entries.into(), + shape_entries.base_ptr(ctx, generator).into(), + dst_ndims.into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index f67566b9..c640042b 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -18,12 +18,14 @@ use crate::codegen::{ }; pub use array::*; pub use basic::*; +pub use broadcast::*; pub use indexing::*; pub use iter::*; pub use reshape::*; mod array; mod basic; +mod broadcast; mod indexing; mod iter; mod reshape; diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs new file mode 100644 index 00000000..5ee28454 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -0,0 +1,176 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::codegen::{ + types::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, + }, + values::{ndarray::ShapeEntryValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ShapeEntryType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ShapeEntryStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> ShapeEntryType<'ctx> { + /// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not. + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_ndarray_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ShapeEntryStructFields<'ctx> { + ShapeEntryStructFields::new(ctx, llvm_usize) + } + + /// See [`ShapeEntryStructFields::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> { + Self::fields(ctx, self.llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_ty = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_ty, llvm_usize } + } + + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type Base = PointerType<'ctx>; + type Value = ShapeEntryValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ShapeEntryType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 17bb6adc..316d0f33 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -20,11 +20,13 @@ use crate::{ toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; mod array; +mod broadcast; mod contiguous; pub mod factory; mod indexing; @@ -118,6 +120,20 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + } + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] pub fn new_unsized( diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs new file mode 100644 index 00000000..0c84b05e --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -0,0 +1,248 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue}, +}; +use itertools::Itertools; + +use crate::codegen::{ + irrt, + types::{ + ndarray::{NDArrayType, ShapeEntryType}, + structure::StructField, + ProxyType, + }, + values::{ + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct ShapeEntryValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ShapeEntryValue<'ctx> { + /// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is + /// not an instance. + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + >::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).ndims + } + + /// Stores the number of dimensions into this value. + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.ndims_field().set(ctx, self.value, value, self.name); + } + + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).shape + } + + /// Stores the shape into this value. + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.shape_field().set(ctx, self.value, value, self.name); + } +} + +impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = ShapeEntryType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ShapeEntryValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); + + let broadcast_ndarray = + NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) + .construct_uninitialized(generator, ctx, None); + broadcast_ndarray.copy_shape_from_array( + generator, + ctx, + target_shape.base_ptr(ctx, generator), + ); + + irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); + broadcast_ndarray + } +} + +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Clone)] +pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + + /// The broadcasting shape. + pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, + + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call [`irrt::ndarray::call_nac3_ndarray_broadcast_shapes`]. +fn broadcast_shapes<'ctx, G, Shape>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + + assert!(in_shape_entries + .iter() + .all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into())); + assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into()); + + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false); + let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = unsafe { + shape_entries.ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + None, + ) + }; + let shape_entry = llvm_shape_ty.map_value(pshape_entry, None); + + let in_ndims = llvm_usize.const_int(*in_ndims, false); + shape_entry.store_ndims(ctx, in_ndims); + + shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator)); + } + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false); + irrt::ndarray::call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayType<'ctx> { + /// Broadcast all ndarrays according to + /// [`np.broadcast()`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html) + /// and return a [`BroadcastAllResult`] containing all the information of the result of the + /// broadcast operation. + pub fn broadcast( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[NDArrayValue<'ctx>], + ) -> BroadcastAllResult<'ctx, G> { + assert!(!ndarrays.is_empty()); + assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); + + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); + assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); + let broadcast_shape = ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(), + broadcast_ndims, + None, + ); + let broadcast_shape = TypedArrayLikeAdapter::from( + broadcast_shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| { + ( + ndarray.shape().as_slice_value(ctx, generator), + ndarray.get_type().ndims().unwrap(), + ) + }) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 951792fb..d91f734d 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -20,10 +20,12 @@ use crate::codegen::{ types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod broadcast; mod contiguous; mod indexing; mod nditer; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index db7acaf3..3f71b983 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -373,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1328,7 +1328,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1356,7 +1359,10 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1386,7 +1392,17 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, &shape) + } + + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, &shape) + } + + _ => unreachable!(), + }; Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9313b13b..de90a41b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -60,6 +60,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -253,6 +254,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index f5349890..41b39bb8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 05408683..90408d91 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 029815aa..f0418889 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index fc8f1aaf..72e54e02 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 34e78a2b..a8a534cd 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8f1c54fc..87692114 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1594,7 +1594,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 56c6126d..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index a82cfbad..374bcf73 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_3(n: ndarray[float, Literal[3]]): + for d in range(len(n)): + for r in range(len(n[d])): + for c in range(len(n[d][r])): + output_float64(n[d][r][c]) + def output_ndarray_float_4(n: ndarray[float, Literal[4]]): for x in range(len(n)): for y in range(len(n[x])): @@ -236,6 +242,23 @@ def test_ndarray_reshape(): output_int32(np_shape(x2)[1]) output_ndarray_int32_2(x2) +def test_ndarray_broadcast_to(): + xs = np_array([1.0, 2.0, 3.0]) + ys = np_broadcast_to(xs, (1, 3)) + zs = np_broadcast_to(ys, (2, 4, 3)) + + output_int32(np_shape(xs)[0]) + output_ndarray_float_1(xs) + + output_int32(np_shape(ys)[0]) + output_int32(np_shape(ys)[1]) + output_ndarray_float_2(ys) + + output_int32(np_shape(zs)[0]) + output_int32(np_shape(zs)[1]) + output_int32(np_shape(zs)[2]) + output_ndarray_float_3(zs) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1619,6 +1642,7 @@ def run() -> int32: test_ndarray_nd_idx() test_ndarray_reshape() + test_ndarray_broadcast_to() test_ndarray_add() test_ndarray_add_broadcast() From 7375983e0c60d04651424aa3cde509b52a20a2d3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:22:59 +0800 Subject: [PATCH 27/46] [core] codegen/ndarray: Implement np_transpose without axes argument Based on 052b67c8: core/ndstrides: implement np_transpose() (no axes argument) The IRRT implementation knows how to handle axes. But the argument is not in NAC3 yet. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/transpose.hpp | 143 ++++++++++++++++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/irrt/ndarray/transpose.rs | 48 ++++++ nac3core/src/codegen/numpy.rs | 108 ------------- nac3core/src/codegen/values/ndarray/view.rs | 50 +++++- nac3core/src/toplevel/builtins.rs | 7 +- nac3standalone/demo/src/ndarray.py | 27 ++-- 8 files changed, 267 insertions(+), 121 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/transpose.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/transpose.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 09844d3b..39ddba67 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -11,4 +11,5 @@ #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" -#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/transpose.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 00000000..662ceb1e --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace { +namespace ndarray::transpose { +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template +void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) { + if (ndims != num_axes) { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == -1) { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray* src_ndarray, NDArray* dst_ndarray, SizeT num_axes, const SizeT* axes) { + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } else { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace ndarray::transpose +} // namespace + +extern "C" { +using namespace ndarray::transpose; +void __nac3_ndarray_transpose(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int32_t num_axes, + const int32_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} + +void __nac3_ndarray_transpose64(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int64_t num_axes, + const int64_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index c640042b..ba22568e 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -22,6 +22,7 @@ pub use broadcast::*; pub use indexing::*; pub use iter::*; pub use reshape::*; +pub use transpose::*; mod array; mod basic; @@ -29,6 +30,7 @@ mod broadcast; mod indexing; mod iter; mod reshape; +mod transpose; /// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs new file mode 100644 index 00000000..57661fa7 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -0,0 +1,48 @@ +use inkwell::{values::IntValue, AddressSpace}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_transpose`. +/// +/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`. +/// +/// `dst_ndarray` must fulfill the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`. +/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, + axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); + assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + infer_and_call_function( + ctx, + &name, + None, + &[ + src_ndarray.as_base_value().into(), + dst_ndarray.as_base_value().into(), + axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), + axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { + axes.base_ptr(ctx, generator) + }) + .into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ce4e2059..fdbb716b 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1307,114 +1307,6 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. -pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (x1_ty, x1): (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = n1.size(generator, ctx); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 8334d379..450f7444 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,6 +1,6 @@ use std::iter::{once, repeat_n}; -use inkwell::values::IntValue; +use inkwell::values::{IntValue, PointerValue}; use itertools::Itertools; use crate::codegen::{ @@ -9,7 +9,7 @@ use crate::codegen::{ types::ndarray::NDArrayType, values::{ ndarray::{NDArrayValue, RustNDIndex}, - ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -108,4 +108,50 @@ impl<'ctx> NDArrayValue<'ctx> { dst_ndarray } + + /// Create a transposed view on this ndarray like + /// [`np.transpose(, = None)`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html). + /// + /// * `axes` - If specified, should be an array of the permutation (negative indices are + /// **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>, + ) -> Self { + assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); + assert!( + axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) + ); + + // Define models + let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); + + let axes = if let Some(axes) = axes { + let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false); + + // `axes = nullptr` if `axes` is unspecified. + let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None); + + Some(TypedArrayLikeAdapter::from( + axes, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + )) + } else { + None + }; + + irrt::ndarray::call_nac3_ndarray_transpose( + generator, + ctx, + *self, + transposed_ndarray, + axes.as_ref(), + ); + + transposed_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 3f71b983..538961a6 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1349,7 +1349,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg_val.into_pointer_value(), None); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.as_base_value().into())) }), ), diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 374bcf73..170ac14c 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -210,6 +210,23 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_transpose(): + x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) + y = np_transpose(x) + z = np_transpose(y) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_ndarray_float_2(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_int32(np_shape(z)[1]) + output_ndarray_float_2(z) + def test_ndarray_reshape(): w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) x = np_reshape(w, (1, 2, 1, -1)) @@ -1502,14 +1519,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) -def test_ndarray_transpose(): - x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) - y = np_transpose(x) - z = np_transpose(y) - - output_ndarray_float_2(x) - output_ndarray_float_2(y) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1641,6 +1650,7 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_transpose() test_ndarray_reshape() test_ndarray_broadcast_to() @@ -1807,7 +1817,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() - test_ndarray_transpose() test_ndarray_dot() test_ndarray_cholesky() From dcde1d9c87bc7629236c0f059eb250869871f896 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:32:34 +0800 Subject: [PATCH 28/46] [core] codegen/values/ndarray: Add more ScalarOrNDArray utils Based on f731e604: core/ndstrides: add more ScalarOrNDArray and NDArrayObject utils --- nac3core/src/codegen/values/ndarray/mod.rs | 121 ++++++++++++++++++--- 1 file changed, 107 insertions(+), 14 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d91f734d..6c8c9aab 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -12,13 +12,16 @@ use super::{ TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; -use crate::codegen::{ - irrt, - llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, - stmt::gen_for_callback_incrementing, - type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + irrt, + llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, + stmt::gen_for_callback_incrementing, + type_aligned_alloca, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, }; pub use broadcast::*; pub use contiguous::*; @@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> { self.ndims.map(|ndims| ndims == 0) } - /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. - /// Otherwise, do nothing and return the ndarray itself. - // TODO: Rename to get_unsized_element - pub fn split_unsized( + /// Returns the element present in this `ndarray` if this is unsized. + pub fn get_unsized_element( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ) -> ScalarOrNDArray<'ctx> { - let Some(is_unsized) = self.is_unsized() else { todo!() }; + ) -> Option> { + let Some(is_unsized) = self.is_unsized() else { + panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + }; if is_unsized { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; - ScalarOrNDArray::Scalar(value) + Some(value) + } else { + None + } + } + + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { + ScalarOrNDArray::Scalar(unsized_elem) } else { ScalarOrNDArray::NDArray(*self) } @@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayValue<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn from_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + ScalarOrNDArray::NDArray(ndarray) + } + + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with + /// [`NDArrayType::construct_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayValue<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => { + NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None) + } + } + } + + /// Get the dtype of the ndarray created if this were called with + /// [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.get_type(), + } + } } From 2dc5e79a23ea630f88b1bef3faa43f52e2e807ef Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:44:57 +0800 Subject: [PATCH 29/46] [core] codegen/ndarray: Implement subscript assignment Based on 5bed394e: core/ndstrides: implement subscript assignment Overlapping is not handled. Currently it has undefined behavior. --- nac3core/src/codegen/stmt.rs | 55 ++++++++++++++++++++++++++++-- nac3standalone/demo/src/ndarray.py | 33 ++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 3595528f..edebb4f0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -16,7 +16,11 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, + types::ndarray::NDArrayType, + values::{ + ndarray::{RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -411,7 +415,54 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + + // Process key + let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + + // Reference code: + // ```python + // target = target[key] + // value = np.asarray(value) + // + // shape = np.broadcast_shape((target, value)) + // + // target = np.broadcast_to(target, shape) + // value = np.broadcast_to(value, shape) + // + // # ...and finally copy 1-1 from value to target. + // ``` + + let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) + .map_value(target.into_pointer_value(), None); + let target = target.index(generator, ctx, &key); + + let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) + .to_ndarray(generator, ctx); + + let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] + .iter() + .filter_map(|ndims| *ndims) + .max(); + let broadcast_result = NDArrayType::new( + generator, + ctx.ctx, + value.get_type().element_type(), + broadcast_ndims, + ) + .broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 170ac14c..b668860f 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -276,6 +276,38 @@ def test_ndarray_broadcast_to(): output_int32(np_shape(zs)[2]) output_ndarray_float_3(zs) +def test_ndarray_subscript_assignment(): + xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]]) + + xs[0, 0] = 99.0 + output_ndarray_float_2(xs) + + xs[0] = 100.0 + output_ndarray_float_2(xs) + + xs[:, ::2] = 101.0 + output_ndarray_float_2(xs) + + xs[1:, 0] = 102.0 + output_ndarray_float_2(xs) + + xs[0] = np_array([-1.0, -2.0, -3.0, -4.0]) + output_ndarray_float_2(xs) + + xs[:] = np_array([-5.0, -6.0, -7.0, -8.0]) + output_ndarray_float_2(xs) + + # Test assignment with memory sharing + ys1 = np_reshape(xs, (2, 4)) + ys2 = np_transpose(ys1) + ys3 = ys2[::-1, 0] + ys3[0] = -999.0 + + output_ndarray_float_2(xs) + output_ndarray_float_2(ys1) + output_ndarray_float_2(ys2) + output_ndarray_float_1(ys3) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1653,6 +1685,7 @@ def run() -> int32: test_ndarray_transpose() test_ndarray_reshape() test_ndarray_broadcast_to() + test_ndarray_subscript_assignment() test_ndarray_add() test_ndarray_add_broadcast() From e6dab25a570643911878d244be5e07b7c162dea3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 17:15:51 +0800 Subject: [PATCH 30/46] [core] codegen/ndarray: Add NDArrayOut, broadcast_map, map Based on fbfc0b29: core/ndstrides: add NDArrayOut, broadcast_map and map --- nac3core/src/codegen/types/ndarray/map.rs | 187 +++++++++++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/values/ndarray/map.rs | 69 ++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 45 +++++ 4 files changed, 302 insertions(+) create mode 100644 nac3core/src/codegen/types/ndarray/map.rs create mode 100644 nac3core/src/codegen/values/ndarray/map.rs diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs new file mode 100644 index 00000000..0d63b22f --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -0,0 +1,187 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; +use itertools::Itertools; + +use crate::codegen::{ + stmt::gen_for_callback, + types::{ + ndarray::{NDArrayType, NDIterType}, + ProxyType, + }, + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ArrayLikeValue, ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayType<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` + /// elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when + /// iterating through the input `ndarrays` after broadcasting. The output of `mapping` is the + /// result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing + /// ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[NDArrayValue<'ctx>], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result<>::Value, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = self.broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = + NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + .construct_uninitialized(generator, ctx, None); + result_ndarray.copy_shape_from_array( + generator, + ctx, + broadcast_result.shape.base_ptr(ctx, generator), + ); + unsafe { + result_ndarray.create_data(generator, ctx); + } + result_ndarray + } + + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| { + NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) + }) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_element()`. + // `in_nditers`' `has_element()`s should return the same value. + Ok(out_nditer.has_element(generator, ctx)) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + let in_scalars = + in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a + /// scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars + /// or ndarrays or a mix of them as their inputs and produces either an ndarray with broadcast, + /// or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following + /// behaviors: + /// + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is + /// converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> + /// ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a + /// [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be + /// 'as-ndarray'-ed into ndarrays, then all inputs (now all ndarrays) will be passed to + /// [`NDArrayValue::broadcasting_starmap`] and **create** a new ndarray with dtype `ret_dtype`. + pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: BasicTypeEnum<'ctx>, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = + inputs.iter().map(BasicValueEnum::<'ctx>::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().copied().collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(value)) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ret_dtype, + &inputs.iter().map(NDArrayValue::get_type).collect_vec(), + ) + .broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 316d0f33..43712be1 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -30,6 +30,7 @@ mod broadcast; mod contiguous; pub mod factory; mod indexing; +mod map; mod nditer; /// Proxy type for a `ndarray` type in LLVM. diff --git a/nac3core/src/codegen/values/ndarray/map.rs b/nac3core/src/codegen/values/ndarray/map.rs new file mode 100644 index 00000000..72d1bf9d --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/map.rs @@ -0,0 +1,69 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; + +use crate::codegen::{ + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Map through this ndarray with an elementwise function. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + self.get_type().broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a + /// [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new + /// ndarray of the results will be created and returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: BasicTypeEnum<'ctx>, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 6c8c9aab..89f88e74 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -31,6 +31,7 @@ pub use nditer::*; mod broadcast; mod contiguous; mod indexing; +mod map; mod nditer; pub mod shape; mod view; @@ -540,6 +541,26 @@ impl<'ctx> NDArrayValue<'ctx> { ScalarOrNDArray::NDArray(*self) } } + + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. + pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_shape: impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) { + let ndarray_shape = self.shape(); + let output_shape = out_shape; + + irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + &ndarray_shape, + &output_shape, + ); + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { @@ -1081,3 +1102,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } } + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a +/// function will create a new ndarray and store the result in it. +#[derive(Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: BasicTypeEnum<'ctx> }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayValue<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +} From 6cbba8fdde439afad1cfab65cb3267ee483e043a Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 11:24:28 +0800 Subject: [PATCH 31/46] [core] codegen: Reimplement builtin funcs to support strided ndarrays Based on 7f3c4530: core/ndstrides: update builtin_fns to use ndarray with strides --- nac3core/src/codegen/builtin_fns.rs | 959 +++++++++++----------------- 1 file changed, 368 insertions(+), 591 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 54650ab3..7c8ad7a8 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -11,19 +11,16 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - numpy, - numpy::ndarray_elementwise_unaryop_impl, - stmt::gen_for_callback_incrementing, types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; use crate::{ toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, }, typecheck::typedef::{Type, TypeEnum}, @@ -129,18 +126,18 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_int32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -189,18 +186,18 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_int64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -265,18 +262,18 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_uint32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -330,18 +327,18 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_uint64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -355,7 +352,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { @@ -394,20 +390,19 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndims = extract_ndims(&ctx.unifier, ndims); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_float(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -440,18 +435,20 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty.into() }, + |generator, ctx, scalar| { + call_round(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -477,18 +474,18 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_numpy_round(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -539,22 +536,21 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (elem_ty, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -591,18 +587,20 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -639,18 +637,20 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -741,42 +741,37 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -861,23 +856,26 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => codegen_unreachable!(ctx), } } + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); - let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = n.size(generator, ctx); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None); + let llvm_dtype = ndarray.get_type().element_type(); + + let zero = llvm_usize.const_zero(); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx + let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") .unwrap(); ctx.make_assert( generator, - n_sz_eqz, + size_nez, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -885,54 +883,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = - generator.gen_var_alloc(ctx, llvm_ndarray_ty.element_type(), None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, llvm_dtype, None)?; + let extremum_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = unsafe { ndarray.data().get_unchecked(ctx, generator, &zero, None) }; + ctx.builder.build_store(extremum, first_value).unwrap(); + ctx.builder.build_store(extremum_idx, zero).unwrap(); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { - n.data().get_unchecked( - ctx, - generator, - &ctx.builder - .build_int_truncate_or_bit_cast(idx, llvm_usize, "") - .unwrap(), - None, - ) - }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |_, ctx, _, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = ctx + .builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); - let result = match fn_name { + let curr_value = nditer.get_scalar(ctx); + let curr_idx = nditer.get_index(ctx); + + let new_extremum = match fn_name { "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_min(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_max(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } _ => codegen_unreachable!(ctx), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), @@ -942,24 +929,35 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + ctx.builder.build_store(extremum_idx, new_extremum_idx).unwrap(); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => ctx + .builder + .build_int_s_extend_or_bit_cast( + ctx.builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(), + ctx.ctx.i64_type(), + "", + ) + .unwrap() + .into(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => codegen_unreachable!(ctx), } } @@ -1006,42 +1004,37 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1074,39 +1067,20 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = ScalarOrNDArray::from_value(generator, ctx, (arg_ty, arg_val)); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(x, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; + let ret_ty = get_ret_elem_type(ctx, dtype); + let llvm_ret_ty = ctx.get_llvm_type(generator, ret_ty); + let result = arg.map(generator, ctx, llvm_ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; - Ok(result) + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1431,59 +1405,29 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. @@ -1495,59 +1439,29 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. @@ -1559,59 +1473,29 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. @@ -1623,59 +1507,29 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. @@ -1687,48 +1541,31 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert_eq!(x1.get_dtype(), ctx.ctx.f64_type().into()); + debug_assert_eq!(x2.get_dtype(), ctx.ctx.i32_type().into()); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. @@ -1740,59 +1577,29 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1804,59 +1611,29 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_nextafter"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function From 59f19e29df1f9765ff6b83cffd3a9ca5d450b145 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:25:35 +0800 Subject: [PATCH 32/46] [core] codegen: Reimplement ndarray binop Based on 9e40c834: core/ndstrides: implement binop --- nac3core/src/codegen/expr.rs | 163 +++++++++++++++++------------------ 1 file changed, 80 insertions(+), 83 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8d7f8e35..8225df7c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,14 +34,19 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{ + helper::{arraylike_flatten_element_type, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, TopLevelDef, + }, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -1526,98 +1531,90 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); + let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if is_ndarray1 && is_ndarray2 { + if op.base == Operator::MatMult { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); - let left_val = NDArrayType::from_unifier_type(generator, ctx, ty1) - .map_value(left_val.into_pointer_value(), None); - let right_val = NDArrayType::from_unifier_type(generator, ctx, ty2) - .map_value(right_val.into_pointer_value(), None); - - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (ty1, left_val.as_base_value().into(), false), - (ty2, right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; - - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = - NDArrayType::from_unifier_type(generator, ctx, if is_ndarray1 { ty1 } else { ty2 }) - .map_value( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( + // MatMult is the only binop which is not an elementwise op + let result = numpy::ndarray_matmul_2d( generator, ctx, - ndarray_dtype, + ndarray_dtype1, match op.variant { BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (ty1, left_val, !is_ndarray1), - (ty2, right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype), lhs), - op, - (&Some(ndarray_dtype), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + BinopVariant::AugAssign => Some(left), }, + left, + right, )?; - Ok(Some(res.as_base_value().into())) + Ok(Some(result.as_base_value().into())) + } else { + // For other operations, they are all elementwise operations. + + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. + + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // If this is an augmented assignment. + // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap(generator, ctx, &[left, right], out, |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(ty1_dtype), left_value), + op, + (&Some(ty2_dtype), right_value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, common_dtype)?; + + Ok(result) + }) + .unwrap(); + Ok(Some(result.as_base_value().into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); From a2f1b25fd881ca34d34ad3304f67fc4b95e54aca Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:37:17 +0800 Subject: [PATCH 33/46] [core] codegen: Reimplement ndarray unary op Based on bb992704: core/ndstrides: implement unary op --- nac3core/src/codegen/expr.rs | 18 +++++------ nac3core/src/codegen/numpy.rs | 59 ----------------------------------- 2 files changed, 8 insertions(+), 69 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8225df7c..232c68c2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1777,10 +1777,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) + .map_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1798,20 +1798,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? + .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; - res.as_base_value().into() + mapped_ndarray.as_base_value().into() } else { unimplemented!() })) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index fdbb716b..d02103ab 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -195,28 +195,6 @@ where }) } -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of /// the target `ndarray`. fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( @@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - /// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output From ebbadc2d742ed1c7e8901e94fbf07fc1721c5401 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:46:24 +0800 Subject: [PATCH 34/46] [core] codegen: Reimplement ndarray cmpop Based on 56cccce1: core/ndstrides: implement cmpop --- nac3core/irrt/irrt/ndarray.hpp | 70 ------- nac3core/src/codegen/expr.rs | 109 ++++------ nac3core/src/codegen/irrt/ndarray/mod.rs | 170 +--------------- nac3core/src/codegen/numpy.rs | 248 +---------------------- 4 files changed, 43 insertions(+), 554 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 7fc9a63b..534f18d6 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n stride *= dims[i]; } } - -template -void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, - SizeT src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} } // namespace extern "C" { @@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32 void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } - -void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} } \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 232c68c2..8a002bb3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1852,83 +1852,52 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty); + let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left)) + .to_ndarray(generator, ctx); + let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right)) + .to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ctx.ctx.i8_type().into(), + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty) - .map_value(lhs.into_pointer_value(), None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, left_val.as_base_value().into(), false), - (right_ty, rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty_dtype), left_scalar), + &[op], + &[(Some(right_ty_dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, lhs, !is_ndarray1), - (right_ty, rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.as_base_value().into())); } } diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index ba22568e..151795c5 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,18 +1,15 @@ use inkwell::{ types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, IntPredicate, + AddressSpace, }; use itertools::Either; use super::get_usize_dependent_function_name; use crate::codegen::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, values::{ ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( |_, _, v| v.into(), ) } - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); - let ndarray_calc_broadcast_fn = - ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.shape().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.shape().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into()) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.shape().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d02103ab..9fe5a972 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -10,10 +10,7 @@ use super::{ expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, - ndarray::{ - call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, - call_ndarray_calc_nd_indices, call_ndarray_calc_size, - }, + ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, }, llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, @@ -21,7 +18,7 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, @@ -195,152 +192,6 @@ where }) } -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. -fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - (lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - (rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Returns the element of an ndarray indexed by the given indices, performing int-promotion on - // `indices` where necessary. - // - // Required for compatibility with `NDArrayType::get_unchecked`. - let get_data_by_indices_compat = - |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { - let llvm_usize = generator.get_size_type(ctx.ctx); - - // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT - let stackptr = llvm_intrinsics::call_stacksave(ctx, None); - let indices = if llvm_usize == ctx.ctx.i32_type() { - indices - } else { - let indices_usize = TypedArrayLikeAdapter::>::from( - ArraySliceValue::from_ptr_val( - ctx.builder - .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") - .unwrap(), - indices.size(ctx, generator), - None, - ), - |_, _, val| val.into_int_value(), - |_, _, val| val.into(), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (indices.size(ctx, generator), false), - |generator, ctx, _, i| { - let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; - let idx = ctx - .builder - .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") - .unwrap(); - unsafe { - indices_usize.set_typed_unchecked(ctx, generator, &i, idx); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - indices_usize - }; - - let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; - - llvm_intrinsics::call_stackrestore(ctx, stackptr); - - elem - }; - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (Type, BasicValueEnum<'ctx>, bool), - rhs: (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let (lhs_ty, lhs_val, lhs_scalar) = lhs; - let (rhs_ty, rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayType::from_unifier_type( - generator, - ctx, - if lhs_scalar { rhs_ty } else { lhs_ty }, - ) - .map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// /// * `elem_ty` - The element type of the `NDArray`. From 66b8a5e01d73d84f1f1312f6915e3977ec7cc24c Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:52 +0800 Subject: [PATCH 35/46] [core] codegen/ndarray: Reimplement matmul Based on 73c2203b: core/ndstrides: implement general matmul --- nac3core/irrt/irrt.cpp | 4 +- nac3core/irrt/irrt/int_types.hpp | 2 - nac3core/irrt/irrt/ndarray.hpp | 50 -- nac3core/irrt/irrt/ndarray/matmul.hpp | 98 +++ nac3core/src/codegen/expr.rs | 67 +- nac3core/src/codegen/irrt/ndarray/matmul.rs | 66 ++ nac3core/src/codegen/irrt/ndarray/mod.rs | 131 +--- nac3core/src/codegen/numpy.rs | 727 +----------------- nac3core/src/codegen/types/ndarray/factory.rs | 2 +- nac3core/src/codegen/values/ndarray/matmul.rs | 334 ++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 1 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 84 +- 17 files changed, 585 insertions(+), 995 deletions(-) delete mode 100644 nac3core/irrt/irrt/ndarray.hpp create mode 100644 nac3core/irrt/irrt/ndarray/matmul.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/matmul.rs create mode 100644 nac3core/src/codegen/values/ndarray/matmul.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 39ddba67..87dcb428 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,7 +1,6 @@ #include "irrt/exception.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" -#include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" #include "irrt/string.hpp" @@ -12,4 +11,5 @@ #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/broadcast.hpp" -#include "irrt/ndarray/transpose.hpp" \ No newline at end of file +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index ed8a48b8..17ccf604 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64); #endif -// NDArray indices are always `uint32_t`. -using NDIndexInt = uint32_t; // The type of an index or a value describing the length of a range/slice is always `int32_t`. using SliceIndex = int32_t; diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp deleted file mode 100644 index 534f18d6..00000000 --- a/nac3core/irrt/irrt/ndarray.hpp +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "irrt/int_types.hpp" - -// TODO: To be deleted since NDArray with strides is done. - -namespace { -template -SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} -} // namespace - -extern "C" { -uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t -__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} -} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..b0fd4d86 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace { +namespace ndarray::matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace ndarray::matmul +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8a002bb3..4b83e63e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -27,7 +27,7 @@ use super::{ call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -1534,26 +1534,35 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if op.base == Operator::MatMult { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // Augmented assignment - `left` has to be an ndarray. If it were a scalar then NAC3 + // simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + if op.base == Operator::MatMult { let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); - - // MatMult is the only binop which is not an elementwise op - let result = numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left), - }, - left, - right, - )?; - - Ok(Some(result.as_base_value().into())) + let result = left + .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) + .split_unsized(generator, ctx); + Ok(Some(result.to_basic_value_enum().into())) } else { // For other operations, they are all elementwise operations. @@ -1565,28 +1574,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // For all cases, the scalar operand is promoted to an ndarray, // the two are then broadcasted, and starmapped through. - let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); - let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); - - // Inhomogeneous binary operations are not supported. - assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); - - let common_dtype = ty1_dtype; - let llvm_common_dtype = left.get_dtype(); - - let out = match op.variant { - BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - BinopVariant::AugAssign => { - // If this is an augmented assignment. - // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. - if let ScalarOrNDArray::NDArray(out_ndarray) = left { - NDArrayOut::WriteToNDArray { ndarray: out_ndarray } - } else { - panic!("left must be an ndarray") - } - } - }; - let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs new file mode 100644 index 00000000..551cb7c7 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -0,0 +1,66 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, + values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`. +/// +/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of +/// `a @ b`. +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + final_ndims: IntValue<'ctx>, + new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!( + BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + + infer_and_call_function( + ctx, + &name, + None, + &[ + a_shape.size(ctx, generator).into(), + a_shape.base_ptr(ctx, generator).into(), + b_shape.size(ctx, generator).into(), + b_shape.base_ptr(ctx, generator).into(), + final_ndims.into(), + new_a_shape.base_ptr(ctx, generator).into(), + new_b_shape.base_ptr(ctx, generator).into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 151795c5..b1530685 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,23 +1,9 @@ -use inkwell::{ - types::BasicTypeEnum, - values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, -}; -use itertools::Either; - -use super::get_usize_dependent_function_name; -use crate::codegen::{ - values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAdapter, - }, - CodeGenContext, CodeGenerator, -}; pub use array::*; pub use basic::*; pub use broadcast::*; pub use indexing::*; pub use iter::*; +pub use matmul::*; pub use reshape::*; pub use transpose::*; @@ -26,119 +12,6 @@ mod basic; mod broadcast; mod indexing; mod iter; +mod matmul; mod reshape; mod transpose; - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns a -/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); - assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); - assert_eq!( - BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - - let ndarray_calc_size_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The `llvm_usize` index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert_eq!(index.get_type(), llvm_usize); - - let ndarray_calc_nd_indices_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9fe5a972..d46a6119 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,736 +1,23 @@ use inkwell::{ - types::BasicType, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - IntPredicate, OptimizationLevel, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; use super::{ - expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, - ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, - }, - llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{factory::ndarray_zero_value, NDArrayType}, - values::{ - ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, - TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, - UntypedArrayLikeMutator, - }, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; -/// Creates an `NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, num_dims, None); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = ndarray.size(generator, ctx); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. -fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -/// Copies a slice of an [`NDArrayValue`] to another. -/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); - - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let src_data_offset = ctx - .builder - .build_int_mul( - src_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - let dst_data_offset = ctx - .builder - .build_int_mul( - dst_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (dst_arr, dst_ptr), - (src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = - if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? - } else { - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.shape().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { ndarray.create_data(generator, ctx) }; - - ndarray - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - }; - let rhs_dim0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_usize.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 300167f7..2d0dca76 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -12,7 +12,7 @@ use crate::{ }; /// Get the zero value in `np.zeros()` of a `dtype`. -pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs new file mode 100644 index 00000000..88a94394 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -0,0 +1,334 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; + +use super::{NDArrayOut, NDArrayValue, RustNDIndex}; +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, + irrt, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ + ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::arraylike_flatten_element_type, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), + (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), +) -> NDArrayValue<'ctx> { + assert!( + in_a.ndims.is_some_and(|ndims| ndims >= 2), + "in_a (which is {:?}) must be compile-time known and >= 2", + in_a.ndims + ); + assert!( + in_b.ndims.is_some_and(|ndims| ndims >= 2), + "in_b (which is {:?}) must be compile-time known and >= 2", + in_b.ndims + ); + + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims = llvm_usize.const_int(ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. + let (lhs, rhs, dst) = { + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_a.shape().base_ptr(ctx, generator), + in_lhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_b.shape().base_ptr(ctx, generator), + in_rhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let dst_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Matmul dimension compatibility is checked here. + irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + &in_lhs_shape, + &in_rhs_shape, + ndims, + &lhs_shape, + &rhs_shape, + &dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); + + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, None); + dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); + unsafe { + dst.create_data(generator, ctx); + } + + (lhs, rhs, dst) + }; + + let len = unsafe { + lhs.shape().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(ndims_int - 1, false), + None, + ) + }; + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices::(); + let i = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None) + }; + let j = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None) + }; + + let num_0 = llvm_usize.const_int(0, false); + let num_1 = llvm_usize.const_int(1, false); + + gen_for_callback_incrementing( + generator, + ctx, + None, + num_0, + (len, false), + |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + k.into(), + ); + } + let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) }; + + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + k.into(), + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) }; + + // Restore `indices`. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs_dtype), a_ik), + Binop::normal(Operator::Mult), + (&Some(rhs_dtype), b_kj), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }, + num_1, + ) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). + /// + /// This function always return an [`NDArrayValue`]. You may want to use + /// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. + #[must_use] + pub fn matmul( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + self_ty: Type, + (other_ty, other): (Type, Self), + (out_dtype, out): (Type, NDArrayOut<'ctx>), + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!( + self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), + "np.matmul disallows scalar input" + ); + + // If both arguments are 2-D they are multiplied like conventional matrices. + // + // If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the + // last two indices and broadcast accordingly. + // + // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its + // dimensions. After matrix multiplication the prepended 1 is removed. + // + // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its + // dimensions. After matrix multiplication the appended 1 is removed. + + let new_a = if self.ndims.unwrap() == 1 { + // Prepend 1 to its dimensions + self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + *self + }; + + let new_b = if other.ndims.unwrap() == 1 { + // Append 1 to its dimensions + other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + other + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = + matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b)); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + let zero = ctx.ctx.i32_type().const_zero(); + + if self.ndims.unwrap() == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if other.ndims.unwrap() == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.shape(); + out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 89f88e74..707c79a2 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -32,6 +32,7 @@ mod broadcast; mod contiguous; mod indexing; mod map; +mod matmul; mod nditer; pub mod shape; mod view; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 41b39bb8..4332b474 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 90408d91..60e0c194 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index f0418889..46601817 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 72e54e02..da58d121 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index a8a534cd..8f384fa1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f03..40bbdeab 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; use super::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, }; @@ -175,19 +175,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +187,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +530,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); From 3ac1083734ff093c77244cc675efb2b9151a56ae Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:32:18 +0800 Subject: [PATCH 36/46] [core] codegen: Reimplement np_dot() for scalars and 1D Based on 693b7f37: core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 136 +++++++++++++++++------------- nac3core/src/toplevel/builtins.rs | 4 +- 2 files changed, 79 insertions(+), 61 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d46a6119..e5a893c9 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,14 +7,18 @@ use nac3parser::ast::StrRef; use super::{ macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, + stmt::gen_for_callback, + types::ndarray::{NDArrayType, NDIterType}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{arraylike_flatten_element_type, extract_ndims}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::typedef::{FunSignature, Type}, }; @@ -300,89 +304,101 @@ pub fn gen_ndarray_fill<'ctx>( pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); + let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = n1.size(generator, ctx); - let n2_sz = n2.size(generator, ctx); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = + ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size), Some(b_size), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx)) + }, + |_, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(ctx); + let b_scalar = b_iter.get_scalar(ctx); - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + _ => codegen_unreachable!( ctx, "{FN_NAME}() not supported for '{}'", diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 538961a6..600276b7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1935,10 +1935,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), From 12fddc3533dcc4c88bac0dfd91a46cf1930723b6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:48:00 +0800 Subject: [PATCH 37/46] [core] codegen/ndarray: Make ndims non-optional Now that everything is ported to use strided impl, dynamic-ndim ndarray instances do not exist anymore. --- nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 2 +- nac3core/src/codegen/builtin_fns.rs | 28 +++---- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 23 +++--- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/codegen/test.rs | 2 +- nac3core/src/codegen/types/ndarray/array.rs | 14 ++-- nac3core/src/codegen/types/ndarray/map.rs | 2 +- nac3core/src/codegen/types/ndarray/mod.rs | 54 ++++--------- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- .../src/codegen/values/ndarray/broadcast.rs | 17 ++--- .../src/codegen/values/ndarray/contiguous.rs | 6 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 33 +++----- nac3core/src/codegen/values/ndarray/mod.rs | 76 +++---------------- nac3core/src/codegen/values/ndarray/view.rs | 11 +-- 17 files changed, 94 insertions(+), 201 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 653f41a3..156ba23e 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -464,7 +464,7 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) .map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -597,7 +597,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 6507dc20..8e9cd10c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1107,7 +1107,7 @@ impl InnerResolver { self.global_value_ids.write().insert(id, obj.into()); } - let ndims = llvm_ndarray.ndims().unwrap(); + let ndims = llvm_ndarray.ndims(); // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 7c8ad7a8..9d368070 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1652,7 +1652,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1694,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1746,8 +1746,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1796,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,7 +1838,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_dyn_shape(generator, ctx, &[d0, d1], None); unsafe { out.create_data(generator, ctx) }; @@ -1880,7 +1880,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1924,7 +1924,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None); + let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, ndims, llvm_usize, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1940,7 +1940,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,7 +1979,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)) + let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) .construct_const_shape(generator, ctx, &[1], None); unsafe { det.create_data(generator, ctx) }; @@ -2008,13 +2008,13 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2053,13 +2053,13 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2ce3c9ab..4e83d530 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -520,7 +520,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into() + NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e5a893c9..6c16be9f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,8 +120,13 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) - .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + generator, + context, + &shape, + fill_value_arg, + None, + ); Ok(ndarray.as_base_value()) } @@ -218,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -246,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -315,8 +320,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. - assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); - assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert_eq!(a.get_type().ndims(), 1); + assert_eq!(b.get_type().ndims(), 1); let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); // Check shapes. diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index edebb4f0..2a3bd066 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -447,10 +447,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) .to_ndarray(generator, ctx); - let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] - .iter() - .filter_map(|ndims| *ndims) - .max(); + let broadcast_ndims = + [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( generator, ctx.ctx, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 97bd3f09..2701e138 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -464,6 +464,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None); + let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 87cd002a..0f30f0eb 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -41,7 +41,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); let list_value = list.as_i8_list(generator, ctx); @@ -61,7 +61,7 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) .construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -93,12 +93,12 @@ impl<'ctx> NDArrayType<'ctx> { if ndims == 1 { // `list` is not nested assert_eq!(ndims, 1); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert!(self.ndims >= ndims); assert_eq!(dtype, self.dtype); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + let ndarray = Self::new(generator, ctx.ctx, dtype, 1) .construct_uninitialized(generator, ctx, name); // Set data @@ -170,7 +170,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. @@ -183,9 +183,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { assert_eq!(ndarray.get_type().dtype, self.dtype); - assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self - .ndims - .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert!(self.ndims >= ndarray.get_type().ndims); assert_eq!(copy.get_type(), ctx.ctx.bool_type()); let ndarray_val = gen_if_else_expr_callback( diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 0d63b22f..bf82b4da 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -47,7 +47,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 43712be1..353ace33 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -38,7 +38,7 @@ mod nditer; pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, } @@ -113,7 +113,7 @@ impl<'ctx> NDArrayType<'ctx> { generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, ) -> Self { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); @@ -132,7 +132,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. @@ -145,7 +145,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize } + NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -164,7 +164,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: Self::llvm_type(ctx.ctx, llvm_usize), dtype: llvm_dtype, - ndims: Some(ndims), + ndims, llvm_usize, } } @@ -174,7 +174,7 @@ impl<'ctx> NDArrayType<'ctx> { pub fn from_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); @@ -196,7 +196,7 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the number of dimensions of this `ndarray` type. #[must_use] - pub fn ndims(&self) -> Option { + pub fn ndims(&self) -> u64 { self.ndims } @@ -286,35 +286,7 @@ impl<'ctx> NDArrayType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - - let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else { - unreachable!() - }; - - self.construct_impl(generator, ctx, ndims, name) - } - - /// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`. - /// - /// `shape` and `strides` will be automatically allocated onto the stack. - /// - /// The returned ndarray's content will be: - /// - `data`: uninitialized. - /// - `itemsize`: set to the size of `dtype`. - /// - `ndims`: set to the value of `ndims`. - /// - `shape`: allocated with an array of length `ndims` with uninitialized values. - /// - `strides`: allocated with an array of length `ndims` with uninitialized values. - #[deprecated = "Prefer construct_uninitialized or construct_*_shape."] - #[must_use] - pub fn construct_dyn_ndims( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> >::Value { - assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)"); + let ndims = self.llvm_usize.const_int(self.ndims, false); self.construct_impl(generator, ctx, ndims, name) } @@ -330,9 +302,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[u64], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -365,9 +337,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[IntValue<'ctx>], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -407,7 +379,7 @@ impl<'ctx> NDArrayType<'ctx> { let value = value.as_basic_value_enum(); assert_eq!(value.get_type(), self.dtype); - assert!(self.ndims.is_none_or(|ndims| ndims == 0)); + assert_eq!(self.ndims, 0); // We have to put the value on the stack to get a data pointer. let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap(); diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 9b71693a..c77e4571 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,13 +163,8 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - assert!( - ndarray.get_type().ndims().is_some(), - "NDIter requires ndims of NDArray to be known." - ); - let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 0c84b05e..1b99f464 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -101,12 +101,11 @@ impl<'ctx> NDArrayValue<'ctx> { target_ndims: u64, target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> Self { - assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = - NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) - .construct_uninitialized(generator, ctx, None); + let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, ctx, @@ -199,14 +198,13 @@ impl<'ctx> NDArrayType<'ctx> { ndarrays: &[NDArrayValue<'ctx>], ) -> BroadcastAllResult<'ctx, G> { assert!(!ndarrays.is_empty()); - assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); let llvm_usize = generator.get_size_type(ctx.ctx); // Infer the broadcast output ndims. let broadcast_ndims_int = - ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); - assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap(); + assert!(self.ndims() >= broadcast_ndims_int); let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); let broadcast_shape = ArraySliceValue::from_ptr_val( @@ -223,10 +221,7 @@ impl<'ctx> NDArrayType<'ctx> { let shape_entries = ndarrays .iter() .map(|ndarray| { - ( - ndarray.shape().as_slice_value(ctx, generator), - ndarray.get_type().ndims().unwrap(), - ) + (ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims()) }) .collect_vec(); broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index f3b03dd1..8eb700b9 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -121,9 +121,7 @@ impl<'ctx> NDArrayValue<'ctx> { .alloca_var(generator, ctx, self.name); // Set ndims and shape. - let ndims = self - .ndims - .map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false)); + let ndims = self.llvm_usize.const_int(self.ndims, false); result.store_ndims(ctx, ndims); let shape = self.shape(); @@ -180,7 +178,7 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) .construct_uninitialized(generator, ctx, carray.name); // Copy shape and update strides diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3d575028..3821f232 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -98,8 +98,8 @@ impl<'ctx> From> for PointerValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Get the expected `ndims` after indexing with `indices`. #[must_use] - fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option { - let mut ndims = self.ndims?; + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; for index in indices { match index { @@ -113,7 +113,7 @@ impl<'ctx> NDArrayValue<'ctx> { } } - Some(ndims) + ndims } /// Index into the ndarray, and return a newly-allocated view on this ndarray. @@ -127,8 +127,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, indices: &[RustNDIndex<'ctx>], ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - let dst_ndims = self.deduce_ndims_after_indexing_with(indices); let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index 88a94394..f802c0c0 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -29,16 +29,8 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), ) -> NDArrayValue<'ctx> { - assert!( - in_a.ndims.is_some_and(|ndims| ndims >= 2), - "in_a (which is {:?}) must be compile-time known and >= 2", - in_a.ndims - ); - assert!( - in_b.ndims.is_some_and(|ndims| ndims >= 2), - "in_b (which is {:?}) must be compile-time known and >= 2", - in_b.ndims - ); + assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims); + assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); @@ -47,13 +39,13 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); // Deduce ndims of the result of matmul. - let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims_int = max(in_a.ndims, in_b.ndims); let ndims = llvm_usize.const_int(ndims_int, false); // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the // destination ndarray to store the result of matmul. let (lhs, rhs, dst) = { - let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false); let in_lhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_a.shape().base_ptr(ctx, generator), @@ -63,7 +55,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( |_, _, val| val.into_int_value(), |_, _, val| val.into(), ); - let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false); let in_rhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_b.shape().base_ptr(ctx, generator), @@ -116,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { @@ -266,10 +258,7 @@ impl<'ctx> NDArrayValue<'ctx> { (out_dtype, out): (Type, NDArrayOut<'ctx>), ) -> Self { // Sanity check, but type inference should prevent this. - assert!( - self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), - "np.matmul disallows scalar input" - ); + assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input"); // If both arguments are 2-D they are multiplied like conventional matrices. // @@ -282,14 +271,14 @@ impl<'ctx> NDArrayValue<'ctx> { // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its // dimensions. After matrix multiplication the appended 1 is removed. - let new_a = if self.ndims.unwrap() == 1 { + let new_a = if self.ndims == 1 { // Prepend 1 to its dimensions self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) } else { *self }; - let new_b = if other.ndims.unwrap() == 1 { + let new_b = if other.ndims == 1 { // Append 1 to its dimensions other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) } else { @@ -305,12 +294,12 @@ impl<'ctx> NDArrayValue<'ctx> { let mut postindices = vec![]; let zero = ctx.ctx.i32_type().const_zero(); - if self.ndims.unwrap() == 1 { + if self.ndims == 1 { // Remove the prepended 1 postindices.push(RustNDIndex::SingleElement(zero)); } - if other.ndims.unwrap() == 1 { + if other.ndims == 1 { // Remove the appended 1 postindices.push(RustNDIndex::Ellipsis); postindices.push(RustNDIndex::SingleElement(zero)); diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 707c79a2..ad50d32e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -42,7 +42,7 @@ mod view; pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -62,7 +62,7 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn from_pointer_value( ptr: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { @@ -245,26 +245,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_shape = src_ndarray.shape().base_ptr(ctx, generator); self.copy_shape_from_array(generator, ctx, src_shape); @@ -296,26 +277,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_strides = src_ndarray.strides().base_ptr(ctx, generator); self.copy_strides_from_array(generator, ctx, src_strides); @@ -380,11 +342,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Self { - let clone = if self.ndims.is_some() { - self.get_type().construct_uninitialized(generator, ctx, None) - } else { - self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None) - }; + let clone = self.get_type().construct_uninitialized(generator, ctx, None); let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); @@ -437,11 +395,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.shape().get_typed_unchecked( @@ -459,7 +415,7 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } @@ -473,11 +429,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.strides().get_typed_unchecked( @@ -495,15 +449,15 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] - pub fn is_unsized(&self) -> Option { - self.ndims.map(|ndims| ndims == 0) + pub fn is_unsized(&self) -> bool { + self.ndims == 0 } /// Returns the element present in this `ndarray` if this is unsized. @@ -512,11 +466,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Option> { - let Some(is_unsized) = self.is_unsized() else { - panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - }; - - if is_unsized { + if self.is_unsized() { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; @@ -534,8 +484,6 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ScalarOrNDArray<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { ScalarOrNDArray::Scalar(unsized_elem) } else { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 450f7444..5027be58 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -26,9 +26,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndmin: u64, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - - let ndims = self.ndims.unwrap(); + let ndims = self.ndims; if ndims < ndmin { // Extend the dimensions with np.newaxis. @@ -67,13 +65,13 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); // Resolve negative indices let size = self.size(generator, ctx); - let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_shape = dst_ndarray.shape(); irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( generator, @@ -121,7 +119,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, axes: Option>, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); assert!( axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) ); @@ -130,7 +127,7 @@ impl<'ctx> NDArrayValue<'ctx> { let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); let axes = if let Some(axes) = axes { - let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false); + let num_axes = self.llvm_usize.const_int(self.ndims, false); // `axes = nullptr` if `axes` is unspecified. let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None); From e480081e4b9d95831958c65721ed42d2d32e3989 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 4 Jan 2025 10:28:27 +0800 Subject: [PATCH 38/46] update dependencies --- Cargo.lock | 103 +++++++++++++++++++++++++++-------------------------- flake.lock | 6 ++-- 2 files changed, 55 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ce59893f..646700d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.4" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = [ "shlex", ] @@ -170,7 +170,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -187,14 +187,14 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static", "libc", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] @@ -221,18 +221,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -249,18 +249,18 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -305,9 +305,9 @@ dependencies = [ [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "equivalent" @@ -378,9 +378,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "hashbrown" @@ -417,11 +417,11 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -472,7 +472,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -559,9 +559,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" @@ -678,7 +678,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", "trybuild", ] @@ -799,7 +799,7 @@ dependencies = [ "phf_shared 0.11.2", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -927,7 +927,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -940,14 +940,14 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1062,9 +1062,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -1095,29 +1095,29 @@ checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.134" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" dependencies = [ "itoa", "memchr", @@ -1226,7 +1226,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -1242,9 +1242,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" dependencies = [ "proc-macro2", "quote", @@ -1265,12 +1265,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1278,9 +1279,9 @@ dependencies = [ [[package]] name = "term" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0" +checksum = "a3bb6001afcea98122260987f8b7b5da969ecad46dbf0b5453702f776b491a41" dependencies = [ "home", "windows-sys 0.52.0", @@ -1325,7 +1326,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -1603,9 +1604,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" dependencies = [ "memchr", ] @@ -1637,5 +1638,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] diff --git a/flake.lock b/flake.lock index 67bb80e4..7672c219 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1733940404, - "narHash": "sha256-Pj39hSoUA86ZePPF/UXiYHHM7hMIkios8TYG29kQT4g=", + "lastModified": 1735834308, + "narHash": "sha256-dklw3AXr3OGO4/XT1Tu3Xz9n/we8GctZZ75ZWVqAVhk=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5d67ea6b4b63378b9c13be21e2ec9d1afc921713", + "rev": "6df24922a1400241dae323af55f30e4318a6ca65", "type": "github" }, "original": { From 8322d457c601a87fbaa2af735a819a0c1c4622e4 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 4 Jan 2025 15:30:24 +0800 Subject: [PATCH 39/46] standalone/demo: numpy2 compatibility --- nac3standalone/demo/interpret_demo.py | 8 ++++++-- nac3standalone/demo/src/numeric_primitives.py | 20 ++++++------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 8784ce53..fa91ed3c 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -67,7 +67,7 @@ def _bool(x): def _float(x): if isinstance(x, np.ndarray): - return np.float_(x) + return np.float64(x) else: return float(x) @@ -111,6 +111,9 @@ def patch(module): def output_strln(x): print(x, end='') + def output_int32_list(x): + print([int(e) for e in x]) + def dbg_stack_address(_): return 0 @@ -126,11 +129,12 @@ def patch(module): return output_float elif name == "output_str": return output_strln + elif name == "output_int32_list": + return output_int32_list elif name in { "output_bool", "output_int32", "output_int64", - "output_int32_list", "output_uint32", "output_uint64", "output_strln", diff --git a/nac3standalone/demo/src/numeric_primitives.py b/nac3standalone/demo/src/numeric_primitives.py index 77a641f8..2c08c16d 100644 --- a/nac3standalone/demo/src/numeric_primitives.py +++ b/nac3standalone/demo/src/numeric_primitives.py @@ -29,10 +29,10 @@ def u32_max() -> uint32: return ~uint32(0) def i32_min() -> int32: - return int32(1 << 31) + return int32(-(1 << 31)) def i32_max() -> int32: - return int32(~(1 << 31)) + return int32((1 << 31)-1) def u64_min() -> uint64: return uint64(0) @@ -63,8 +63,9 @@ def test_conv_from_i32(): i32_max() ]: output_int64(int64(x)) - output_uint32(uint32(x)) - output_uint64(uint64(x)) + if x >= 0: + output_uint32(uint32(x)) + output_uint64(uint64(x)) output_float64(float(x)) def test_conv_from_u32(): @@ -108,7 +109,6 @@ def test_conv_from_u64(): def test_f64toi32(): for x in [ - float(i32_min()) - 1.0, float(i32_min()), float(i32_min()) + 1.0, -1.5, @@ -117,7 +117,6 @@ def test_f64toi32(): 1.5, float(i32_max()) - 1.0, float(i32_max()), - float(i32_max()) + 1.0 ]: output_int32(int32(x)) @@ -138,24 +137,17 @@ def test_f64toi64(): def test_f64tou32(): for x in [ - -1.5, - float(u32_min()) - 1.0, - -0.5, float(u32_min()), 0.5, float(u32_min()) + 1.0, 1.5, float(u32_max()) - 1.0, float(u32_max()), - float(u32_max()) + 1.0 ]: output_uint32(uint32(x)) def test_f64tou64(): for x in [ - -1.5, - float(u64_min()) - 1.0, - -0.5, float(u64_min()), 0.5, float(u64_min()) + 1.0, @@ -181,4 +173,4 @@ def run() -> int32: test_f64tou32() test_f64tou64() - return 0 \ No newline at end of file + return 0 From d9c180ed13df48fcfcc2b6b148741c09732742bb Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 13:36:51 +0800 Subject: [PATCH 40/46] [artiq] symbol_resolver: Fix support for np.bool_ -> bool decay --- nac3artiq/demo/numpy_primitives_decay.py | 29 ++++++++++++++++++++++++ nac3artiq/src/symbol_resolver.rs | 22 +++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 nac3artiq/demo/numpy_primitives_decay.py diff --git a/nac3artiq/demo/numpy_primitives_decay.py b/nac3artiq/demo/numpy_primitives_decay.py new file mode 100644 index 00000000..957d363f --- /dev/null +++ b/nac3artiq/demo/numpy_primitives_decay.py @@ -0,0 +1,29 @@ +from min_artiq import * +import numpy +from numpy import int32 + + +@nac3 +class NumpyBoolDecay: + core: KernelInvariant[Core] + np_true: KernelInvariant[bool] + np_false: KernelInvariant[bool] + np_int: KernelInvariant[int32] + np_float: KernelInvariant[float] + np_str: KernelInvariant[str] + + def __init__(self): + self.core = Core() + self.np_true = numpy.True_ + self.np_false = numpy.False_ + self.np_int = numpy.int32(0) + self.np_float = numpy.float64(0.0) + self.np_str = numpy.str_("") + + @kernel + def run(self): + pass + + +if __name__ == "__main__": + NumpyBoolDecay().run() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 8e9cd10c..0f9d57bd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -931,10 +931,13 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - obj.extract::().map_or_else( - |_| Ok(Err(format!("{obj} is not in the range of bool"))), - |_| Ok(Ok(extracted_ty)), - ) + if let Ok(_) = obj.extract::() { + Ok(Ok(extracted_ty)) + } else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::() { + Ok(Ok(extracted_ty)) + } else { + Ok(Err(format!("{obj} is not in the range of bool"))) + } } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of float64"))), @@ -974,10 +977,14 @@ impl InnerResolver { let val: u64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); @@ -1413,9 +1420,12 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.uint64 { let val: u64 = obj.extract()?; Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract()?; + Ok(SymbolValue::Bool(val)) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract()?; Ok(SymbolValue::Str(val)) From 2271b46b9631f8b3892666fbdcefe9061621dd8c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:31:17 +0800 Subject: [PATCH 41/46] [core] codegen/values/ndarray: Fix Vec allocation --- nac3core/src/codegen/values/ndarray/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index ad50d32e..b32a8f63 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -950,10 +950,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, /// This function is used generating strides for globally defined contiguous ndarrays. #[must_use] pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { - let mut strides = Vec::with_capacity(ndims as usize); + let mut strides = vec![0u64; ndims as usize]; let mut stride_product = 1u64; - for i in 0..ndims { - let axis = ndims - i - 1; + for axis in (0..ndims).rev() { strides[axis as usize] = stride_product * itemsize; stride_product *= shape[axis as usize]; } From 4e21def1a026caa2f860bf80ae47607b5a103cff Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:36:35 +0800 Subject: [PATCH 42/46] [artiq] symbol_resolver: Add missing promotion for host compilation Shape tuple is always in i32, so a zero-extension to i64 is necessary when assigning the shape tuple into the shape field of the ndarray. --- nac3artiq/src/symbol_resolver.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0f9d57bd..04d224cd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1131,7 +1131,10 @@ impl InnerResolver { super::CompileError::new_err(format!("Error getting element {i}: {e}")) })? .unwrap(); - let value = value.into_int_value(); + let value = ctx + .builder + .build_int_z_extend(value.into_int_value(), llvm_usize, "") + .unwrap(); Ok(value) }) .collect::, PyErr>>()?; From 3c5e247195f52b2296cca20bf54267a017e24b1e Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:39:24 +0800 Subject: [PATCH 43/46] [artiq] symbol_resolver: Use TargetData to get size of dtype dtype.size_of() may not return a constant value. --- nac3artiq/src/symbol_resolver.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 04d224cd..7b70cc49 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1213,8 +1213,16 @@ impl InnerResolver { data_global.set_initializer(&data); // Get the constant itemsize. - let itemsize = dtype.size_of().unwrap(); - let itemsize = itemsize.get_zero_extended_constant().unwrap(); + // + // NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size` + // will always return a constant size. + let itemsize = ctx + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data().get_store_size(&dtype)) + .unwrap(); + assert_ne!(itemsize, 0); // Create the strides needed for ndarray.strides let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); From 352c7c880b3604f78b84d31f86114ce8c79b2830 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:41:21 +0800 Subject: [PATCH 44/46] [artiq] symbol_resolver: Fix incorrect global type for ndarray.strides --- nac3artiq/src/symbol_resolver.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 7b70cc49..75600cc2 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1232,7 +1232,7 @@ impl InnerResolver { // create a global for ndarray.strides and initialize it let strides_global = ctx.module.add_global( - llvm_i8.array_type(ndims as u32), + llvm_usize.array_type(ndims as u32), Some(AddressSpace::default()), &format!("${id_str}.strides"), ); From eaaa194a87082ee5530f724b45c9822466aafe56 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:42:53 +0800 Subject: [PATCH 45/46] [artiq] symbol_resolver: Cast ndarray.{shape,strides} globals to usize* This is needed as ndarray.{shapes,strides} are ArrayValues, and so a GEP operation is required to convert them into pointers to their first elements. --- nac3artiq/src/symbol_resolver.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 75600cc2..232cd084 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1248,9 +1248,30 @@ impl InnerResolver { let ndarray_ndims = llvm_usize.const_int(ndims, false); + // calling as_pointer_value on shape and strides returns [i64 x ndims]* + // convert into i64* to conform with expected layout of ndarray + let ndarray_shape = shape_global.as_pointer_value(); + let ndarray_shape = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_shape, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; let ndarray_strides = strides_global.as_pointer_value(); + let ndarray_strides = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_strides, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; let ndarray = llvm_ndarray .as_base_type() From 8baf111734c6c29504a0e1804b8f8d56600896f7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 17:08:03 +0800 Subject: [PATCH 46/46] [meta] Apply clippy suggestions --- nac3artiq/src/codegen.rs | 6 +++--- nac3artiq/src/lib.rs | 4 ++-- nac3artiq/src/symbol_resolver.rs | 6 +++--- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/toplevel/test.rs | 2 +- nac3core/src/typecheck/function_check.rs | 4 ++-- nac3core/src/typecheck/type_error.rs | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 4 ++-- nac3ld/src/dwarf.rs | 8 ++++---- nac3ld/src/lib.rs | 2 +- 11 files changed, 21 insertions(+), 21 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 156ba23e..9baa0afe 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -162,7 +162,7 @@ impl<'a> ArtiqCodeGenerator<'a> { } } -impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { +impl CodeGenerator for ArtiqCodeGenerator<'_> { fn get_name(&self) -> &str { &self.name } @@ -1505,7 +1505,7 @@ pub fn call_rtio_log_impl<'ctx>( /// Generates a call to `core_log`. pub fn gen_core_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, @@ -1522,7 +1522,7 @@ pub fn gen_core_log<'ctx>( /// Generates a call to `rtio_log`. pub fn gen_rtio_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ca2f2f15..1601d4a4 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -330,7 +330,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_core_log(ctx, &obj, fun, &args, generator)?; + gen_core_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), @@ -360,7 +360,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_rtio_log(ctx, &obj, fun, &args, generator)?; + gen_rtio_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 232cd084..1e99ac20 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -931,9 +931,9 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - if let Ok(_) = obj.extract::() { - Ok(Ok(extracted_ty)) - } else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::() { + if obj.extract::().is_ok() + || obj.call_method("__bool__", (), None)?.extract::().is_ok() + { Ok(Ok(extracted_ty)) } else { Ok(Err(format!("{obj} is not in the range of bool"))) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4b83e63e..e74e7ad8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -79,7 +79,7 @@ pub fn get_subst_key( .join(", ") } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Builds a sequence of `getelementptr` and `load` instructions which stores the value of a /// struct field into an LLVM value. pub fn build_gep_and_load( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 4e83d530..e74071bc 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -228,7 +228,7 @@ pub struct CodeGenContext<'ctx, 'a> { pub current_loc: Location, } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl CodeGenContext<'_, '_> { /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 1f33a4ba..6a83632e 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -807,7 +807,7 @@ struct TypeToStringFolder<'a> { unifier: &'a mut Unifier, } -impl<'a> Fold> for TypeToStringFolder<'a> { +impl Fold> for TypeToStringFolder<'_> { type TargetU = String; type Error = String; fn map_user(&mut self, user: Option) -> Result { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index ed801a17..2e655d11 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -15,7 +15,7 @@ use super::{ }; use crate::toplevel::helper::PrimDef; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)])) @@ -94,7 +94,7 @@ impl<'a> Inferencer<'a> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) - && !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id()) + && ty.obj_id(self.unifier).is_none_or(|id| id != PrimDef::List.id()) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { return Err(HashSet::from([format!( diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index 3cf5bdc4..b144f8a9 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -94,7 +94,7 @@ fn loc_to_str(loc: Option) -> String { } } -impl<'a> Display for DisplayTypeError<'a> { +impl Display for DisplayTypeError<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { use TypeErrorKind::*; let mut notes = Some(HashMap::new()); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 87692114..742fa197 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located) { } } -impl<'a> Fold<()> for Inferencer<'a> { +impl Fold<()> for Inferencer<'_> { type TargetU = Option; type Error = InferenceError; @@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> { type InferenceResult = Result; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { /// Constrain a <: b /// Currently implemented as unification fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { diff --git a/nac3ld/src/dwarf.rs b/nac3ld/src/dwarf.rs index e85a4e40..4bcccd32 100644 --- a/nac3ld/src/dwarf.rs +++ b/nac3ld/src/dwarf.rs @@ -30,7 +30,7 @@ pub struct DwarfReader<'a> { pub virt_addr: u32, } -impl<'a> DwarfReader<'a> { +impl DwarfReader<'_> { pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { DwarfReader { slice, virt_addr } } @@ -113,7 +113,7 @@ pub struct DwarfWriter<'a> { pub offset: usize, } -impl<'a> DwarfWriter<'a> { +impl DwarfWriter<'_> { pub fn new(slice: &mut [u8]) -> DwarfWriter { DwarfWriter { slice, offset: 0 } } @@ -375,7 +375,7 @@ pub struct FDE_Records<'a> { available: usize, } -impl<'a> Iterator for FDE_Records<'a> { +impl Iterator for FDE_Records<'_> { type Item = (u32, u32); fn next(&mut self) -> Option { @@ -423,7 +423,7 @@ pub struct EH_Frame_Hdr<'a> { fdes: Vec<(u32, u32)>, } -impl<'a> EH_Frame_Hdr<'a> { +impl EH_Frame_Hdr<'_> { /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// /// Load address is not known at this point. diff --git a/nac3ld/src/lib.rs b/nac3ld/src/lib.rs index 73e065dd..8ef8653a 100644 --- a/nac3ld/src/lib.rs +++ b/nac3ld/src/lib.rs @@ -159,7 +159,7 @@ struct SymbolTableReader<'a> { strtab: &'a [u8], } -impl<'a> SymbolTableReader<'a> { +impl SymbolTableReader<'_> { pub fn find_index_by_name(&self, sym_name: &[u8]) -> Option { self.symtab.iter().position(|sym| { if let Ok(dynsym_name) = name_starting_at_slice(self.strtab, sym.st_name as usize) {