From 49a7469b4a80d534cbf167b7d2b8d616e8d7c649 Mon Sep 17 00:00:00 2001 From: ram Date: Mon, 30 Dec 2024 13:02:09 +0800 Subject: [PATCH 01/80] use memcmp for string comparison Co-authored-by: ram Co-committed-by: ram --- nac3core/irrt/irrt.cpp | 2 + nac3core/irrt/irrt/string.hpp | 23 ++++++ nac3core/src/codegen/expr.rs | 108 ++++++---------------------- nac3core/src/codegen/irrt/mod.rs | 2 + nac3core/src/codegen/irrt/string.rs | 48 +++++++++++++ 5 files changed, 95 insertions(+), 88 deletions(-) create mode 100644 nac3core/irrt/irrt/string.hpp create mode 100644 nac3core/src/codegen/irrt/string.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 8447fc5a..722ed32d 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -4,7 +4,9 @@ #include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/string.hpp" diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 00000000..f695dcdc --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "irrt/int_types.hpp" + +namespace { +template +SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + if (len1 != len2){ + return 0; + } + return (__builtin_memcmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; +} +} // namespace + +extern "C" { +uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} + +uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0118ca43..c616449f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -24,7 +24,7 @@ use super::{ irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_int_umin, call_memcpy_generic, + call_memcpy_generic, }, macros::codegen_unreachable, need_sret, numpy, @@ -2045,111 +2045,43 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let lhs = lhs.into_struct_value(); let rhs = rhs.into_struct_value(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(prhs, rhs).unwrap(); + let lhs_ptr = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); + + let rhs_ptr = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - - let len = call_int_umin(ctx, lhs_len, rhs_len, None); - - let current_bb = ctx.builder.get_insert_block().unwrap(); - let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); - ctx.builder.position_at_end(current_bb); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let lhs_char = { - let plhs_data = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - plhs_data, - &[i], - None - ).into_int_value() - }; - let rhs_char = { - let prhs_data = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - prhs_data, - &[i], - None - ).into_int_value() - }; - - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) - }, - |_, ctx| { - let bb = ctx.builder.get_insert_block().unwrap(); - cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let bb = ctx.builder.get_insert_block().unwrap(); - let is_len_eq = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_len, - rhs_len, - "", - ).unwrap(); - cmp_phi.add_incoming(&[(&is_len_eq, bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = cmp_phi.as_basic_value().into_int_value(); - - // Invert the final value if __ne__ + let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { - cmp_phi + result } } else if [left_ty, right_ty] .iter() diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 824921cd..21a16bdb 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -15,12 +15,14 @@ pub use list::*; pub use math::*; pub use range::*; pub use slice::*; +pub use string::*; mod list; mod math; pub mod ndarray; mod range; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 00000000..fb0f27b9 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,48 @@ +use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; +use itertools::Either; + +use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator}; + +/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. +pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() { + 32 => ("nac3_str_eq", ctx.ctx.i32_type()), + 64 => ("nac3_str_eq64", ctx.ctx.i64_type()), + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + + let func = ctx.module.get_function(func_name).unwrap_or_else(|| { + ctx.module.add_function( + func_name, + return_type.fn_type( + &[ + str1_ptr.get_type().into(), + str1_len.get_type().into(), + str2_ptr.get_type().into(), + str2_len.get_type().into(), + ], + false, + ), + None, + ) + }); + let result = ctx + .builder + .build_call( + func, + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + "str_eq_call", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + generator.bool_to_i1(ctx, result) +} From 456aefa6ee1e71386e1dacc2a7e569a5ebc70f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Mon, 30 Dec 2024 13:03:31 +0800 Subject: [PATCH 02/80] clean up duplicate include --- nac3core/irrt/irrt.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 722ed32d..0d069869 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,3 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" -#include "irrt/string.hpp" From fbf0053c248a58421829b628b7c0f62be62a77f1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 30 Dec 2024 14:04:42 +0800 Subject: [PATCH 03/80] [core] irrt/string: Minor cleanup - Refactor __nac3_str_eq to always return bool - Use `get_usize_dependent_function_name` to get IRRT func name --- nac3core/irrt/irrt/string.hpp | 8 ++++---- nac3core/src/codegen/irrt/string.rs | 24 +++++++++++------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index f695dcdc..db3ad7fd 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -4,20 +4,20 @@ namespace { template -SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { +bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { if (len1 != len2){ return 0; } - return (__builtin_memcmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; + return __builtin_memcmp(str1, str2, static_cast(len1)) == 0; } } // namespace extern "C" { -uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { +bool nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } -uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { +bool nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index fb0f27b9..6ee40e45 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,7 +1,8 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; -use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator}; +use super::get_usize_dependent_function_name; +use crate::codegen::{CodeGenContext, CodeGenerator}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( @@ -12,16 +13,14 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( str2_ptr: PointerValue<'ctx>, str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { - let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() { - 32 => ("nac3_str_eq", ctx.ctx.i32_type()), - 64 => ("nac3_str_eq64", ctx.ctx.i64_type()), - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let llvm_i1 = ctx.ctx.bool_type(); - let func = ctx.module.get_function(func_name).unwrap_or_else(|| { + let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq"); + + let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { ctx.module.add_function( - func_name, - return_type.fn_type( + &func_name, + llvm_i1.fn_type( &[ str1_ptr.get_type().into(), str1_len.get_type().into(), @@ -33,8 +32,8 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( None, ) }); - let result = ctx - .builder + + ctx.builder .build_call( func, &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], @@ -43,6 +42,5 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) - .unwrap(); - generator.bool_to_i1(ctx, result) + .unwrap() } From 0e5940c49d0b71b4bd19e7e7b7fdb6a705f46dbe Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:43:24 +0800 Subject: [PATCH 04/80] [meta] Refactor itertools::{chain,enumerate,repeat_n} with std equiv --- nac3core/src/codegen/expr.rs | 4 ++-- nac3core/src/symbol_resolver.rs | 12 ++++++------ nac3core/src/toplevel/composer.rs | 2 +- nac3core/src/typecheck/typedef/mod.rs | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c616449f..606e7421 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -11,7 +11,7 @@ use inkwell::{ values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::{chain, izip, Either, Itertools}; +use itertools::{izip, Either, Itertools}; use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, @@ -1965,7 +1965,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } } - let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) + let cmp_val = izip!(once(&left).chain(comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; let (right_ty, rhs) = rhs; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index bab823c1..2378dd62 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -6,7 +6,7 @@ use std::{ }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; -use itertools::{chain, izip, Itertools}; +use itertools::{izip, Itertools}; use parking_lot::RwLock; use nac3parser::ast::{Constant, Expr, Location, StrRef}; @@ -452,11 +452,11 @@ pub fn parse_type_annotation( type_vars.len() )])); } - let fields = chain( - fields.iter().map(|(k, v, m)| (*k, (*v, *m))), - methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(); + let fields = fields + .iter() + .map(|(k, v, m)| (*k, (*v, *m))) + .chain(methods.iter().map(|(k, v, _)| (*k, (*v, false)))) + .collect(); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() })) } else { Err(HashSet::from([format!("Cannot use function name as type at {loc}")])) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 6040ced1..bd9a9214 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1052,7 +1052,7 @@ impl TopLevelComposer { } let mut result = Vec::new(); let no_defaults = args.args.len() - args.defaults.len() - 1; - for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) { + for (idx, x) in args.args.iter().skip(1).enumerate() { let type_ann = { let Some(annotation_expr) = x.node.annotation.as_ref() else {return Err(HashSet::from([format!("type annotation needed for `{}` (at {})", x.node.arg, x.location)]));}; parse_ast_to_type_annotation_kinds( diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 49cea04c..e190c4c4 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -3,13 +3,13 @@ use std::{ cell::RefCell, collections::{HashMap, HashSet}, fmt::{self, Display}, - iter::{repeat, zip}, + iter::{repeat, repeat_n, zip}, rc::Rc, sync::{Arc, Mutex}, }; use indexmap::IndexMap; -use itertools::{repeat_n, Itertools}; +use itertools::Itertools; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; From 35e3042435813ac07a129482cb7a382b9c840744 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 13:58:02 +0800 Subject: [PATCH 05/80] [core] Refactor/Remove redundant and unused constructs - Use ProxyValue.name where necessary - Remove NDArrayValue::ptr_to_{shape,strides} - Remove functions made obsolete by ndstrides - Remove use statement for ndarray::views as it only contain an impl block. - Remove class_names field in Resolvers of test sources --- nac3core/src/codegen/expr.rs | 324 +----------------- nac3core/src/codegen/llvm_intrinsics.rs | 33 +- nac3core/src/codegen/numpy.rs | 23 -- nac3core/src/codegen/test.rs | 15 +- nac3core/src/codegen/values/array.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 13 - nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/toplevel/builtins.rs | 146 +------- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/toplevel/test.rs | 18 +- .../src/typecheck/type_inferencer/test.rs | 4 - 15 files changed, 23 insertions(+), 573 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 606e7421..c3dfc80e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,19 +34,14 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::{NDArrayValue, RustNDIndex}, - ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, - TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::{extract_ndims, PrimDef}, - numpy::unpack_ndarray_var_tys, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -2444,319 +2439,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims_ty: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims_ty) else { - codegen_unreachable!(ctx) - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? - else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? { - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - let num_dims = extract_ndims(&ctx.unifier, ndims_ty) - 1; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let ndarray = - NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t, Some(num_dims)) - .construct_uninitialized(generator, ctx, None); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.shape().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.shape().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = ndarray::call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - unsafe { ndarray.create_data(generator, ctx) }; - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index 895339d0..9c360b6a 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -1,7 +1,6 @@ use inkwell::{ - context::Context, intrinsics::Intrinsic, - types::{AnyTypeEnum::IntType, FloatType}, + types::AnyTypeEnum::IntType, values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}, AddressSpace, }; @@ -9,34 +8,6 @@ use itertools::Either; use super::CodeGenContext; -/// Returns the string representation for the floating-point type `ft` when used in intrinsic -/// functions. -fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { - // Standard LLVM floating-point types - if ft == ctx.f16_type() { - return "f16"; - } - if ft == ctx.f32_type() { - return "f32"; - } - if ft == ctx.f64_type() { - return "f64"; - } - if ft == ctx.f128_type() { - return "f128"; - } - - // Non-standard floating-point types - if ft == ctx.x86_f80_type() { - return "f80"; - } - if ft == ctx.ppc_f128_type() { - return "ppcf128"; - } - - unreachable!() -} - /// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) /// intrinsic. pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { @@ -54,7 +25,7 @@ pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap(); } -/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) +/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic) /// intrinsic. pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { const FN_NAME: &str = "llvm.va_end"; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index cd113aae..513f8f51 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -604,29 +604,6 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( } } -/// Returns the number of dimensions for an array-like object as an [`IntValue`]. -fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (ty, value): (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match value { - BasicValueEnum::PointerValue(v) - if NDArrayValue::is_representable(v, llvm_usize).is_ok() => - { - NDArrayType::from_unifier_type(generator, ctx, ty).map_value(v, None).load_ndims(ctx) - } - - BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { - llvm_ndlist_get_ndims(generator, ctx, v.get_type()) - } - - _ => llvm_usize.const_zero(), - } -} - /// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 81c5836c..97bd3f09 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -36,7 +36,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: RwLock>, - class_names: HashMap, } impl Resolver { @@ -104,11 +103,9 @@ fn test_primitives() { let top_level = Arc::new(composer.make_top_level_context()); unifier.top_level = Some(top_level.clone()); - let resolver = Arc::new(Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }) as Arc; + let resolver = + Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) + as Arc; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let signature = FunSignature { @@ -298,11 +295,7 @@ fn test_simple_call() { loc: None, }))); - let resolver = Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }; + let resolver = Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }; resolver.add_id_def("foo".into(), DefinitionId(foo_id)); let resolver = Arc::new(resolver) as Arc; diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 78975f06..e6ebb258 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -389,7 +389,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + let var_name = name.or(self.2).map(|v| format!("{v}.addr")).unwrap_or_default(); unsafe { ctx.builder diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 12fd8634..eef74407 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -19,7 +19,6 @@ use crate::codegen::{ pub use contiguous::*; pub use indexing::*; pub use nditer::*; -pub use view::*; mod contiguous; mod indexing; @@ -113,12 +112,6 @@ impl<'ctx> NDArrayValue<'ctx> { self.get_type().get_fields(ctx.ctx).shape } - /// Returns the double-indirection pointer to the `shape` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name) - } - /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); @@ -147,12 +140,6 @@ impl<'ctx> NDArrayValue<'ctx> { self.get_type().get_fields(ctx.ctx).strides } - /// Returns the double-indirection pointer to the `strides` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.strides_field(ctx).ptr_by_gep(ctx, self.value, self.name) - } - /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 45a82b38..218afc19 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -78,7 +78,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element(ctx).get(ctx, self.as_base_value(), None); + let p = self.element(ctx).get(ctx, self.as_base_value(), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -98,7 +98,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth(ctx).get(ctx, self.as_base_value(), None) + self.nth(ctx).get(ctx, self.as_base_value(), self.name) } /// Get the indices of the current element. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e382222c..36bd85d1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,13 +1,7 @@ use std::iter::once; use indexmap::IndexMap; -use inkwell::{ - attributes::{Attribute, AttributeLoc}, - types::{BasicMetadataTypeEnum, BasicType}, - values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, - IntPredicate, -}; -use itertools::Either; +use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ @@ -148,144 +142,6 @@ fn create_fn_by_codegen( } } -/// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic. -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function. -fn create_fn_by_intrinsic( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - intrinsic_fn: &'static str, -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - - ctx.module.add_function(intrinsic_fn, fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, args_val.as_slice(), name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - -/// Creates a unary NumPy [`TopLevelDef`] function using an extern function (e.g. from `libc` or -/// `libm`). -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation. -/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is -/// already implied by the C ABI. -fn create_fn_by_extern( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - extern_fn: &'static str, - attrs: &'static [&str], -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - let func = ctx.module.add_function(extern_fn, fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), - ); - - for attr in attrs { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - - func - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &args_val, name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo { BuiltinBuilder::new(unifier, primitives) .build_all_builtins() diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 1b0c9b80..93f2096b 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2621337c..d3301d00 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d0769305..911426b9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 5ebdf86c..d60daf83 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a9c4dd6..517f6846 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 37a3dede..1f33a4ba 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -15,14 +15,13 @@ use crate::{ symbol_resolver::{SymbolResolver, ValueEnum}, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{into_var_map, Type, Unifier}, + typedef::{Type, Unifier}, }, }; struct ResolverInternal { id_to_type: Mutex>, id_to_def: Mutex>, - class_names: Mutex>, } impl ResolverInternal { @@ -179,11 +178,8 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { let mut composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; - let internal_resolver = Arc::new(ResolverInternal { - id_to_def: Mutex::default(), - id_to_type: Mutex::default(), - class_names: Mutex::default(), - }); + let internal_resolver = + Arc::new(ResolverInternal { id_to_def: Mutex::default(), id_to_type: Mutex::default() }); let resolver = Arc::new(Resolver(internal_resolver.clone())) as Arc; @@ -784,13 +780,6 @@ fn make_internal_resolver_with_tvar( unifier: &mut Unifier, print: bool, ) -> Arc { - let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None); - let list = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimDef::List.id(), - fields: HashMap::new(), - params: into_var_map([list_elem_tvar]), - }); - let res: Arc = ResolverInternal { id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])), id_to_type: tvars @@ -806,7 +795,6 @@ fn make_internal_resolver_with_tvar( }) .collect::>() .into(), - class_names: Mutex::new(HashMap::from([("list".into(), list)])), } .into(); if print { diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index e56cb283..a6583536 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -18,7 +18,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: HashMap, - class_names: HashMap, } impl SymbolResolver for Resolver { @@ -198,7 +197,6 @@ impl TestEnvironment { let resolver = Arc::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: HashMap::default(), - class_names: HashMap::default(), }) as Arc; TestEnvironment { @@ -454,7 +452,6 @@ impl TestEnvironment { vars: IndexMap::default(), })), ); - let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into(); let id_to_name = [ "int32".into(), @@ -492,7 +489,6 @@ impl TestEnvironment { ("Bar2".into(), DefinitionId(defs + 3)), ] .into(), - class_names, }) as Arc; TestEnvironment { From 318371a5093f04afe9e029fa4b2cf2d60accea06 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 20 Dec 2024 13:19:40 +0800 Subject: [PATCH 06/80] [core] irrt: Minor cleanup --- nac3core/irrt/irrt/ndarray.hpp | 2 +- nac3core/irrt/irrt/ndarray/basic.hpp | 6 ++---- nac3core/irrt/irrt/ndarray/indexing.hpp | 9 ++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 9d305aa4..72ca0b9e 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -148,4 +148,4 @@ void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, NDIndexInt* out_idx) { __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); } -} // namespace \ No newline at end of file +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp index 05ee30fc..62c92ae2 100644 --- a/nac3core/irrt/irrt/ndarray/basic.hpp +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -6,8 +6,7 @@ #include "irrt/ndarray/def.hpp" namespace { -namespace ndarray { -namespace basic { +namespace ndarray::basic { /** * @brief Assert that `shape` does not contain negative dimensions. * @@ -247,8 +246,7 @@ void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); } } -} // namespace basic -} // namespace ndarray +} // namespace ndarray::basic } // namespace extern "C" { diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index 9e9e7b6e..76e78473 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -65,8 +65,7 @@ struct NDIndex { } // namespace namespace { -namespace ndarray { -namespace indexing { +namespace ndarray::indexing { /** * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) * @@ -162,7 +161,8 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ Range range = slice->indices_checked(src_ndarray->shape[src_axis]); - dst_ndarray->data = static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->data = + static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; dst_ndarray->shape[dst_axis] = (SizeT)range.len(); @@ -197,8 +197,7 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); } -} // namespace indexing -} // namespace ndarray +} // namespace ndarray::indexing } // namespace extern "C" { From 19122e29050315740302b3a2b7146e1aa3968ce0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 13 Dec 2024 16:35:34 +0800 Subject: [PATCH 07/80] [core] codegen: Rename classes/functions for consistency - ContiguousNDArrayFields -> ContiguousNDArrayStructFields - ndarray/nditer: Add _field suffix to field accessors --- nac3core/src/codegen/types/ndarray/contiguous.rs | 14 +++++++------- nac3core/src/codegen/values/ndarray/nditer.rs | 11 +++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 4401cb62..317539c0 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -31,7 +31,7 @@ pub struct ContiguousNDArrayType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct ContiguousNDArrayFields<'ctx> { +pub struct ContiguousNDArrayStructFields<'ctx> { #[value_type(usize)] pub ndims: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize.ptr_type(AddressSpace::default()))] @@ -40,12 +40,12 @@ pub struct ContiguousNDArrayFields<'ctx> { pub data: StructField<'ctx, PointerValue<'ctx>>, } -impl<'ctx> ContiguousNDArrayFields<'ctx> { +impl<'ctx> ContiguousNDArrayStructFields<'ctx> { #[must_use] pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let mut counter = FieldIndexCounter::default(); - ContiguousNDArrayFields { + ContiguousNDArrayStructFields { ndims: StructField::create(&mut counter, "ndims", llvm_usize), shape: StructField::create( &mut counter, @@ -72,7 +72,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { )); }; - let fields = ContiguousNDArrayFields::new(ctx, llvm_usize); + let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); check_struct_type_matches_fields( fields, @@ -93,14 +93,14 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { fn fields( item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, - ) -> ContiguousNDArrayFields<'ctx> { - ContiguousNDArrayFields::new_typed(item, llvm_usize) + ) -> ContiguousNDArrayStructFields<'ctx> { + ContiguousNDArrayStructFields::new_typed(item, llvm_usize) } /// See [`NDArrayType::fields`]. // TODO: Move this into e.g. StructProxyType #[must_use] - pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> { + pub fn get_fields(&self) -> ContiguousNDArrayStructFields<'ctx> { Self::fields(self.item, self.llvm_usize) } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 218afc19..4b31f274 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -69,7 +69,10 @@ impl<'ctx> NDIterValue<'ctx> { irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); } - fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + fn element_field( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> StructField<'ctx, PointerValue<'ctx>> { self.get_type().get_fields(ctx.ctx).element } @@ -78,7 +81,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element(ctx).get(ctx, self.as_base_value(), self.name); + let p = self.element_field(ctx).get(ctx, self.as_base_value(), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -91,14 +94,14 @@ impl<'ctx> NDIterValue<'ctx> { ctx.builder.build_load(p, "value").unwrap() } - fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + fn nth_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { self.get_type().get_fields(ctx.ctx).nth } /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth(ctx).get(ctx, self.as_base_value(), self.name) + self.nth_field(ctx).get(ctx, self.as_base_value(), self.name) } /// Get the indices of the current element. From dc413dfa4331ff68664ab10d45828d124b166bb7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 18:03:03 +0800 Subject: [PATCH 08/80] [core] codegen: Refactor TypedArrayLikeAdapter to use fn Allows for greater flexibility when TypedArrayLikeAdapter is used with custom value types. --- nac3core/src/codegen/irrt/ndarray/mod.rs | 20 ++- nac3core/src/codegen/numpy.rs | 2 +- nac3core/src/codegen/values/array.rs | 128 +++++++++--------- nac3core/src/codegen/values/list.rs | 4 +- nac3core/src/codegen/values/ndarray/mod.rs | 54 +++++--- nac3core/src/codegen/values/ndarray/nditer.rs | 17 +-- 6 files changed, 117 insertions(+), 108 deletions(-) diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index a05e0ce3..56d9094d 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -87,7 +87,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_void = ctx.ctx.void_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -129,8 +129,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ) } @@ -227,7 +227,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); @@ -326,11 +326,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) + TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into()) } /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] @@ -345,7 +341,7 @@ pub fn call_ndarray_calc_broadcast_index< ctx: &mut CodeGenContext<'ctx, '_>, array: NDArrayValue<'ctx>, broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { +) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); @@ -385,7 +381,7 @@ pub fn call_ndarray_calc_broadcast_index< TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ) } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 513f8f51..30a33f08 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -356,7 +356,7 @@ where ValueFn: Fn( &mut G, &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, + &TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, ) -> Result, String>, { ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index e6ebb258..9f3ec0e4 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -51,8 +51,8 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { /// This function should be called with a valid index. unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx>; @@ -76,8 +76,8 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn get_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> BasicValueEnum<'ctx> { @@ -107,8 +107,8 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn set_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: BasicValueEnum<'ctx>, ) { @@ -130,32 +130,33 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: } /// An array-like value that can have its array elements accessed as an arbitrary type `T`. -pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeAccessor<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeAccessor<'ctx, Index> { /// Casts an element from [`BasicValueEnum`] into `T`. fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T; /// # Safety /// /// This function should be called with a valid index. - unsafe fn get_typed_unchecked( + unsafe fn get_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> T { let value = unsafe { self.get_unchecked(ctx, generator, idx, name) }; - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } /// Returns the data at the `idx`-th index. - fn get_typed( + fn get_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, @@ -163,62 +164,62 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: name: Option<&str>, ) -> T { let value = self.get(ctx, generator, idx, name); - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } } /// An array-like value that can have its array elements mutated as an arbitrary type `T`. -pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeMutator<'ctx, Index> { /// Casts an element from T into [`BasicValueEnum`]. fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx>; /// # Safety /// /// This function should be called with a valid index. - unsafe fn set_typed_unchecked( + unsafe fn set_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); unsafe { self.set_unchecked(ctx, generator, idx, value) } } /// Sets the data at the `idx`-th index. - fn set_typed( + fn set_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); self.set(ctx, generator, idx, value); } } -/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. -type ValueDowncastFn<'ctx, T> = - Box, BasicValueEnum<'ctx>) -> T + 'ctx>; -/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. -type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; - /// An adapter for constraining untyped array values as typed values. -pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> { +pub struct TypedArrayLikeAdapter< + 'ctx, + G: CodeGenerator + ?Sized, + T, + Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>, +> { adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, } -impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { @@ -229,61 +230,62 @@ where /// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`]. pub fn from( adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, ) -> Self { TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn } } } -impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> ArrayLikeValue<'ctx> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { - fn element_type( + fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> AnyTypeEnum<'ctx> { self.adapted.element_type(ctx, generator) } - fn base_ptr( + fn base_ptr( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> PointerValue<'ctx> { self.adapted.base_ptr(ctx, generator) } - fn size( + fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } } -impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeIndexer<'ctx, Index>, { - unsafe fn ptr_offset_unchecked( + unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) } } - fn ptr_offset( + fn ptr_offset( &self, ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + generator: &mut CG, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -291,44 +293,46 @@ where } } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T { - (self.downcast_fn)(ctx, value) + (self.downcast_fn)(ctx, generator, value) } } -impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeMutator<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx> { - (self.upcast_fn)(ctx, value) + (self.upcast_fn)(ctx, generator, value) } } @@ -384,8 +388,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> { impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 7b1975f0..549bfe3f 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -199,8 +199,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index eef74407..0da3a2ee 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -478,8 +478,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -517,20 +517,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_ impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -570,8 +576,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -609,20 +615,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -667,8 +679,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -748,17 +760,19 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { let llvm_usize = generator.get_size_type(ctx.ctx); - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); + let indices_elem_ty = unsafe { + indices + .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .get_type() + .get_element_type() + }; let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { panic!("Expected list[int32] but got {indices_elem_ty}") }; diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 4b31f274..e29770e0 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -4,7 +4,7 @@ use inkwell::{ AddressSpace, }; -use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator}; +use super::{NDArrayValue, ProxyValue}; use crate::codegen::{ irrt, stmt::{gen_for_callback, BreakContinueHooks}, @@ -106,18 +106,13 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the indices of the current element. #[must_use] - pub fn get_indices( - &'ctx self, - ) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>> - { + pub fn get_indices( + &self, + ) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { TypedArrayLikeAdapter::from( self.indices, - Box::new(|ctx, val| { - ctx.builder - .build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "") - .unwrap() - }), - Box::new(|_, val| val.into()), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), ) } } From d5e8df070adc1c91595881aacc382b79ab887310 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 16:10:00 +0800 Subject: [PATCH 09/80] [core] Minor improvements to IRRT and add missing documentation --- nac3core/src/codegen/generator.rs | 1 + nac3core/src/codegen/irrt/list.rs | 52 ++++++---- nac3core/src/codegen/irrt/math.rs | 20 +++- nac3core/src/codegen/irrt/mod.rs | 11 ++- nac3core/src/codegen/irrt/ndarray/basic.rs | 88 ++++++++++++++--- nac3core/src/codegen/irrt/ndarray/indexing.rs | 5 + nac3core/src/codegen/irrt/ndarray/iter.rs | 20 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 99 ++++++++----------- nac3core/src/codegen/irrt/range.rs | 18 +++- nac3core/src/codegen/types/ndarray/nditer.rs | 19 ++-- nac3core/src/codegen/values/array.rs | 8 ++ nac3core/src/codegen/values/ndarray/mod.rs | 4 + nac3core/src/codegen/values/ndarray/nditer.rs | 4 + 13 files changed, 238 insertions(+), 111 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index f277ec9a..be007c2a 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -17,6 +17,7 @@ pub trait CodeGenerator { /// Return the module name for the code generator. fn get_name(&self) -> &str; + /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index a7fec59d..2c57f8e7 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,42 +24,52 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(dest_idx.0.get_type(), llvm_i32); + assert_eq!(dest_idx.1.get_type(), llvm_i32); + assert_eq!(dest_idx.2.get_type(), llvm_i32); + assert_eq!(src_idx.0.get_type(), llvm_i32); + assert_eq!(src_idx.1.get_type(), llvm_i32); + assert_eq!(src_idx.2.get_type(), llvm_i32); + + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); let slice_assign_fun = { let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step + llvm_i32.into(), // dest start idx + llvm_i32.into(), // dest end idx + llvm_i32.into(), // dest step elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step + llvm_i32.into(), // dest arr len + llvm_i32.into(), // src start idx + llvm_i32.into(), // src end idx + llvm_i32.into(), // src step elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size + llvm_i32.into(), // src arr len + llvm_i32.into(), // size ]; ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); + let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); ctx.module.add_function(fun_symbol, fn_t, None) }) }; - let zero = int32.const_zero(); - let one = int32.const_int(1, false); + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let dest_len = + ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + let src_len = + ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and @@ -136,7 +146,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( BasicTypeEnum::StructType(t) => t.size_of().unwrap(), _ => codegen_unreachable!(ctx), }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap() } .into(), ]; @@ -147,6 +157,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( .map(Either::unwrap_left) .unwrap() }; + // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); @@ -155,7 +166,8 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 4bc95913..33445b2a 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -62,8 +62,13 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isinf", fn_type, None) }); @@ -84,8 +89,13 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isnan", fn_type, None) }); @@ -104,6 +114,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gamma", fn_type, None) @@ -121,6 +133,8 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gammaln", fn_type, None) @@ -138,6 +152,8 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_j0", fn_type, None) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 21a16bdb..4cacdccb 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -132,10 +132,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( generator: &mut G, length: IntValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap(); + let llvm_i32 = ctx.ctx.i32_type(); + + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); + let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap(); Ok(Some(match (start, end, step) { (s, e, None) => ( if let Some(s) = s.as_ref() { @@ -144,7 +145,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( None => return Ok(None), } } else { - int32.const_zero() + llvm_i32.const_zero() }, { let e = if let Some(s) = e.as_ref() { diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 0daea1c4..d11c9b8d 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; @@ -7,19 +8,26 @@ use crate::codegen::{ expr::{create_and_call_function, infer_and_call_function}, irrt::get_usize_dependent_function_name, types::ProxyType, - values::{ndarray::NDArrayValue, ProxyValue}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`. +/// +/// Assets that `shape` does not contain negative dimensions. pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - shape: PointerValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -30,23 +38,37 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ctx, &name, Some(llvm_usize.into()), - &[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())], + &[ + (llvm_usize.into(), shape.size(ctx, generator).into()), + (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), + ], None, None, ); } +/// Generates a call to `__nac3_ndarray_util_assert_shape_output_shape_same`. +/// +/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to +/// an `ndarray`. pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndarray_ndims: IntValue<'ctx>, - ndarray_shape: PointerValue<'ctx>, - output_ndims: IntValue<'ctx>, - output_shape: IntValue<'ctx>, + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -58,16 +80,20 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_ndims.into()), - (llvm_pusize.into(), ndarray_shape.into()), - (llvm_usize.into(), output_ndims.into()), - (llvm_pusize.into(), output_shape.into()), + (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), + (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), + (llvm_usize.into(), output_shape.size(ctx, generator).into()), + (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), ], None, None, ); } +/// Generates a call to `__nac3_ndarray_size`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an +/// `ndarray`, corresponding to the value of `ndarray.size`. pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -90,6 +116,10 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_nbytes`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the +/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -112,6 +142,10 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_len`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of +/// the `ndarray`, corresponding to the value of `ndarray.__len__`. pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -134,6 +168,9 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_is_c_contiguous`. +/// +/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -156,6 +193,9 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_nth_pelement`. +/// +/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -167,6 +207,8 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!(index.get_type(), llvm_usize); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( @@ -181,11 +223,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`. +/// +/// `indices` must have the same number of elements as the number of dimensions in `ndarray`. +/// +/// Returns a [`PointerValue`] to the element indexed by `indices`. pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: PointerValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); @@ -193,6 +240,11 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); @@ -202,7 +254,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized Some(llvm_pi8.into()), &[ (llvm_ndarray.into(), ndarray.as_base_value().into()), - (llvm_pusize.into(), indices.into()), + (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), None, @@ -211,6 +263,9 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized .unwrap() } +/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. +/// +/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -231,6 +286,11 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_ndarray_copy_data`. +/// +/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number +/// of elements in `src_ndarray` must be greater than or equal to the number of elements in +/// `dst_ndarray`. pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 0821b2cd..3e2c908d 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -5,6 +5,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_index`. +/// +/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) +/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the +/// operation `dst_ndarray = src_ndarray[indices]`. pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 966d6605..47cd5b29 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue}, AddressSpace, }; @@ -9,21 +10,29 @@ use crate::codegen::{ types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArrayLikeValue, ArraySliceValue, ProxyValue, + ProxyValue, TypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_nditer_initialize`. +/// +/// Initializes the `iter` object. pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ndarray: NDArrayValue<'ctx>, - indices: ArraySliceValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); create_and_call_function( @@ -40,6 +49,10 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_nditer_initialize_has_element`. +/// +/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` +/// object. pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -59,6 +72,9 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_nditer_next`. +/// +/// Moves `iter` to point to the next element. pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56d9094d..b74ace0f 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,10 +1,11 @@ use inkwell::{ - types::IntType, + types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; +use super::get_usize_dependent_function_name; use crate::codegen::{ llvm_intrinsics, macros::codegen_unreachable, @@ -23,8 +24,8 @@ mod basic; mod indexing; mod iter; -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. +/// Generates a call to `__nac3_ndarray_calc_size`. Returns a +/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. /// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, @@ -43,18 +44,22 @@ where let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); + assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); + assert_eq!( + BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let ndarray_calc_size_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], false, ); let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { + ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); @@ -76,10 +81,10 @@ where .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] /// containing `i32` indices of the flattened index. /// -/// * `index` - The index to compute the multidimensional index for. +/// * `index` - The `llvm_usize` index to compute the multidimensional index for. /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an /// `NDArray`. pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( @@ -94,19 +99,18 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert_eq!(index.get_type(), llvm_usize); + + let ndarray_calc_nd_indices_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { let fn_type = llvm_void.fn_type( &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], false, ); - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -134,15 +138,21 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for +/// the multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx, G, Index>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: &Indices, + indices: &Index, ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, + Index: ArrayLikeIndexer<'ctx>, { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -163,19 +173,16 @@ where "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" ); - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_flatten_index_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], false, ); - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -201,27 +208,8 @@ where index } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] +/// containing the size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -231,13 +219,10 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_calc_broadcast_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[ llvm_pusize.into(), @@ -249,7 +234,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( false, ); - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) }); let lhs_ndims = lhs.load_ndims(ctx); diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 47c63c4f..3b6bc31d 100644 --- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -6,6 +6,13 @@ use itertools::Either; use crate::codegen::{CodeGenContext, CodeGenerator}; +/// Invokes the `__nac3_range_slice_len` in IRRT. +/// +/// - `start`: The `i32` start value for the slice. +/// - `end`: The `i32` end value for the slice. +/// - `step`: The `i32` step value for the slice. +/// +/// Returns an `i32` value of the length of the slice. pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -14,9 +21,15 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( step: IntValue<'ctx>, ) -> IntValue<'ctx> { const SYMBOL: &str = "__nac3_range_slice_len"; + + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(start.get_type(), llvm_i32); + assert_eq!(end.get_type(), llvm_i32); + assert_eq!(step.get_type(), llvm_i32); + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); @@ -33,6 +46,7 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); + ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .map(CallSiteValue::try_as_basic_value) diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c9b6b7d5..772d5b23 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -14,7 +14,7 @@ use crate::codegen::{ types::structure::{check_struct_type_matches_fields, StructField, StructFields}, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArraySliceValue, ProxyValue, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -128,6 +128,11 @@ impl<'ctx> NDIterType<'ctx> { } /// Allocate an [`NDIter`] that iterates through the given `ndarray`. + /// + /// Note: This function allocates an array on the stack at the current builder location, which + /// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to + /// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the + /// [`NDIter`] is no longer needed. #[must_use] pub fn construct( &self, @@ -141,16 +146,12 @@ impl<'ctx> NDIterType<'ctx> { // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap(); + let indices = + TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = >::Value::from_pointer_value( - nditer, - ndarray, - indices, - self.llvm_usize, - None, - ); + let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); - irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices); + irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); nditer } diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 9f3ec0e4..55e91b21 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -265,6 +265,14 @@ where ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } + + fn as_slice_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, + ) -> ArraySliceValue<'ctx> { + self.adapted.as_slice_value(ctx, generator) + } } impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 0da3a2ee..4c5be432 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -358,6 +358,10 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); } + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and + /// copy the contents over. + /// + /// The new ndarray will own its data and will be C-contiguous. #[must_use] pub fn make_copy( &self, diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index e29770e0..4b4e07a1 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -141,6 +141,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// get properties of the current iteration (e.g., the current element, indices, etc.) + /// + /// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and + /// after invoking this function respectively. See [`NDIterType::construct`] for an explanation + /// on why this is suggested. pub fn foreach<'a, G, F>( &self, generator: &mut G, From 3c0ce3031fb21476d7dedd6109d00d61e81fba61 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 13:44:14 +0800 Subject: [PATCH 10/80] [core] codegen: Update raw_alloca to return PointerValue Better match the expected behavior of alloca. --- nac3core/src/codegen/types/list.rs | 4 ++-- nac3core/src/codegen/types/mod.rs | 11 ++++++++--- nac3core/src/codegen/types/ndarray/contiguous.rs | 2 +- nac3core/src/codegen/types/ndarray/indexing.rs | 2 +- nac3core/src/codegen/types/ndarray/mod.rs | 2 +- nac3core/src/codegen/types/ndarray/nditer.rs | 2 +- nac3core/src/codegen/types/range.rs | 4 ++-- nac3core/src/codegen/types/utils/slice.rs | 4 ++-- 8 files changed, 18 insertions(+), 13 deletions(-) diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 3d041349..8de30867 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; @@ -167,7 +167,7 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 022f897b..03d6d387 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -16,7 +16,11 @@ //! the returned object. This is similar to a `new` expression in C++ but the object is allocated //! on the stack. -use inkwell::{context::Context, types::BasicType, values::IntValue}; +use inkwell::{ + context::Context, + types::BasicType, + values::{IntValue, PointerValue}, +}; use super::{ values::{ArraySliceValue, ProxyValue}, @@ -53,13 +57,14 @@ pub trait ProxyType<'ctx>: Into { llvm_ty: Self::Base, ) -> Result<(), String>; - /// Creates a new value of this type, returning the LLVM instance of this value. + /// Creates a new value of this type by invoking `alloca`, returning a [`PointerValue`] instance + /// representing the allocated value. fn raw_alloca( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base; + ) -> PointerValue<'ctx>; /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the /// resulting array. diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 317539c0..4be55474 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -218,7 +218,7 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 959d4f57..6bbd3e8b 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -176,7 +176,7 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 4127ffa8..b655f352 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -430,7 +430,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 772d5b23..ed98aa4b 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -203,7 +203,7 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index e704455b..dc241b6e 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; @@ -131,7 +131,7 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index aba0efb2..dd7643f7 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; use itertools::Itertools; @@ -215,7 +215,7 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base { + ) -> PointerValue<'ctx> { generator .gen_var_alloc( ctx, From dc9efa9e8c34eb398c2f28b828df122a70615591 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:08 +0800 Subject: [PATCH 11/80] [core] codegen/ndarray: Use IRRT for size() and indexing operations Also refactor some usages of call_ndarray_calc_size with ndarray.size(). --- nac3core/irrt/irrt/ndarray.hpp | 31 ------- nac3core/src/codegen/builtin_fns.rs | 14 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 76 +---------------- nac3core/src/codegen/numpy.rs | 98 +++++++++++++++++----- nac3core/src/codegen/values/ndarray/mod.rs | 80 ++++-------------- 5 files changed, 106 insertions(+), 193 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 72ca0b9e..7fc9a63b 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n } } -template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, - SizeT num_dims, - const NDIndexInt* indices, - SizeT num_indices) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - template void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims, @@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } -uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, - uint64_t num_dims, - const NDIndexInt* indices, - uint64_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims, const uint32_t* rhs_dims, diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a41b9f55..b21e721e 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = - irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); + let n_sz = n.size(generator, ctx); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx .builder @@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( llvm_int64.const_int(1, false), (n_sz, false), |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; + let elem = unsafe { + n.data().get_unchecked( + ctx, + generator, + &ctx.builder + .build_int_truncate_or_bit_cast(idx, llvm_usize, "") + .unwrap(), + None, + ) + }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index b74ace0f..56017c94 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{BasicTypeEnum, IntType}, + types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; @@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for -/// the multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); - let ndarray_flatten_index_fn = - ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] -/// containing the size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 30a33f08..9328bb83 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,8 +21,8 @@ use super::{ stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, types::{ndarray::NDArrayType, ListType, ProxyType}, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, + ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, @@ -318,12 +318,7 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); + let ndarray_num_elems = ndarray.size(generator, ctx); gen_for_callback_incrementing( generator, @@ -434,6 +429,66 @@ where rhs_val.get_type() ); + // Returns the element of an ndarray indexed by the given indices, performing int-promotion on + // `indices` where necessary. + // + // Required for compatibility with `NDArrayType::get_unchecked`. + let get_data_by_indices_compat = + |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT + let stackptr = llvm_intrinsics::call_stacksave(ctx, None); + let indices = if llvm_usize == ctx.ctx.i32_type() { + indices + } else { + let indices_usize = TypedArrayLikeAdapter::>::from( + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") + .unwrap(), + indices.size(ctx, generator), + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (indices.size(ctx, generator), false), + |generator, ctx, _, i| { + let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; + let idx = ctx + .builder + .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") + .unwrap(); + unsafe { + indices_usize.set_typed_unchecked(ctx, generator, &i, idx); + } + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + indices_usize + }; + + let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; + + llvm_intrinsics::call_stackrestore(ctx, stackptr); + + elem + }; + // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) @@ -455,7 +510,7 @@ where .map_value(lhs_val.into_pointer_value(), None); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } + get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) }; let rhs_elem = if rhs_scalar { @@ -465,7 +520,7 @@ where .map_value(rhs_val.into_pointer_value(), None); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } + get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) }; value_fn(generator, ctx, (lhs_elem, rhs_elem)) @@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); if cfg!(debug_assertions) { @@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() }; let idx0 = unsafe { let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() }; let idx1 = unsafe { let idx1 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() }; let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; @@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( generator, ctx, None, - llvm_i32.const_zero(), + llvm_usize.const_zero(), (common_dim, false), |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - let ab_idx = generator.gen_array_var_alloc( ctx, - llvm_i32.into(), + llvm_usize.into(), llvm_usize.const_int(2, false), None, )?; @@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); // Dimensions are reversed in the transposed array let out = create_ndarray_dyn_shape( @@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ); // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); + let out_sz = out.size(generator, ctx); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), @@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n1_sz = n1.size(generator, ctx); + let n2_sz = n2.size(generator, ctx); ctx.make_assert( generator, diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 4c5be432..e47876c6 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -5,8 +5,8 @@ use inkwell::{ }; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ irrt, @@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_ndarray_calc_size( - generator, - ctx, - &self.as_slice_value(ctx, generator), - (None, None), - ) + irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) } } @@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - idx.get_type(), - "", - ) - .unwrap(); - let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[idx], - name.unwrap_or_default(), - ) - .unwrap() - }; + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); // Current implementation is transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } @@ -769,52 +747,28 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); - let indices_elem_ty = unsafe { - indices - .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type() - }; - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() + let indices = TypedArrayLikeAdapter::from( + indices.as_slice_value(ctx, generator), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ); - let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices); - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - index.get_type(), - "", - ) - .unwrap(); - let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); + let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices( + generator, ctx, *self.0, &indices, + ); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - }; - // TODO: Current implementation is transparent + // Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. ctx.builder .build_pointer_cast( ptr, BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } From 2f0847d77b428d5c38dcb0fbc7abb3724a202f52 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 16:48:31 +0800 Subject: [PATCH 12/80] [core] codegen/types: Refactor ProxyType - Add alloca_type() function to obtain the type that should be passed into a `build_alloca` call - Provide default implementations for raw_alloca and array_alloca - Add raw_alloca_var and array_alloca_var to distinguish alloca instructions placed at the front of the function vs at the current builder location --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/types/list.rs | 57 +++++++---------- nac3core/src/codegen/types/mod.rs | 58 +++++++++++++++--- .../src/codegen/types/ndarray/contiguous.rs | 61 ++++++++----------- .../src/codegen/types/ndarray/indexing.rs | 56 +++++++---------- nac3core/src/codegen/types/ndarray/mod.rs | 60 ++++++++---------- nac3core/src/codegen/types/ndarray/nditer.rs | 61 +++++++++---------- nac3core/src/codegen/types/range.rs | 51 ++++++---------- nac3core/src/codegen/types/utils/slice.rs | 61 ++++++++----------- .../src/codegen/values/ndarray/contiguous.rs | 2 +- .../src/codegen/values/ndarray/indexing.rs | 2 +- 11 files changed, 224 insertions(+), 247 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c3dfc80e..4b781d26 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1108,7 +1108,7 @@ pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( // List structure; type { ty*, size_t } let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca(generator, ctx, name); + let list = arr_ty.alloca_var(generator, ctx, name); let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); list.store_size(ctx, generator, length); diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 8de30867..6608a808 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,13 +1,12 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::codegen::{ - values::{ArraySliceValue, ListValue, ProxyValue}, + values::{ListValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -113,15 +112,33 @@ impl<'ctx> ListType<'ctx> { } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.llvm_usize, name, ) @@ -162,36 +179,8 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 03d6d387..98bd43ba 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -57,24 +57,66 @@ pub trait ProxyType<'ctx>: Into { llvm_ty: Self::Base, ) -> Result<(), String>; - /// Creates a new value of this type by invoking `alloca`, returning a [`PointerValue`] instance - /// representing the allocated value. - fn raw_alloca( + /// Returns the type that should be used in `alloca` IR statements. + fn alloca_type(&self) -> impl BasicType<'ctx>; + + /// Creates a new value of this type by invoking `alloca` at the current builder location, + /// returning a [`PointerValue`] instance representing the allocated value. + fn raw_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> PointerValue<'ctx> { + ctx.builder + .build_alloca(self.alloca_type().as_basic_type_enum(), name.unwrap_or_default()) + .unwrap() + } + + /// Creates a new value of this type by invoking `alloca` at the beginning of the function, + /// returning a [`PointerValue`] instance representing the allocated value. + fn raw_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> PointerValue<'ctx>; + ) -> PointerValue<'ctx> { + generator.gen_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), name).unwrap() + } - /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the - /// resulting array. - fn array_alloca( + /// Creates a new array value of this type by invoking `alloca` at the current builder location, + /// returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca( + self.alloca_type().as_basic_type_enum(), + size, + name.unwrap_or_default(), + ) + .unwrap(), + size, + name, + ) + } + + /// Creates a new array value of this type by invoking `alloca` at the beginning of the + /// function, returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, size: IntValue<'ctx>, name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx>; + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), size, name) + .unwrap() + } /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 4be55474..e5fb8cdc 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -16,7 +16,7 @@ use crate::{ }, ProxyType, }, - values::{ndarray::ContiguousNDArrayValue, ArraySliceValue, ProxyValue}, + values::{ndarray::ContiguousNDArrayValue, ProxyValue}, CodeGenContext, CodeGenerator, }, toplevel::numpy::unpack_ndarray_var_tys, @@ -157,16 +157,37 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { Self { ty: ptr_ty, item, llvm_usize } } - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.item, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.item, self.llvm_usize, name, @@ -213,36 +234,8 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 6bbd3e8b..644e173c 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -90,15 +90,33 @@ impl<'ctx> NDIndexType<'ctx> { Self { ty: ptr_ty, llvm_usize } } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.llvm_usize, name, ) @@ -114,7 +132,7 @@ impl<'ctx> NDIndexType<'ctx> { ) -> ArraySliceValue<'ctx> { // Allocate the LLVM ndindices. let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false); - let ndindices = self.array_alloca(generator, ctx, num_ndindices, None); + let ndindices = self.array_alloca_var(generator, ctx, num_ndindices, None); // Initialize all of them. for (i, in_ndindex) in in_ndindices.iter().enumerate() { @@ -171,36 +189,8 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index b655f352..3886ce84 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -14,7 +14,7 @@ use super::{ }; use crate::{ codegen::{ - values::{ndarray::NDArrayValue, ArraySliceValue, ProxyValue, TypedArrayLikeMutator}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeMutator}, {CodeGenContext, CodeGenerator}, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, @@ -182,15 +182,35 @@ impl<'ctx> NDArrayType<'ctx> { } /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.dtype, self.ndims, self.llvm_usize, @@ -214,7 +234,7 @@ impl<'ctx> NDArrayType<'ctx> { ndims: IntValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { - let ndarray = self.alloca(generator, ctx, name); + let ndarray = self.alloca_var(generator, ctx, name); let itemsize = ctx .builder @@ -425,36 +445,8 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index ed98aa4b..7ce8ed79 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -109,8 +109,31 @@ impl<'ctx> NDIterType<'ctx> { self.llvm_usize } + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + parent, + indices, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -119,7 +142,7 @@ impl<'ctx> NDIterType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), parent, indices, self.llvm_usize, @@ -140,7 +163,7 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - let nditer = self.raw_alloca(generator, ctx, None); + let nditer = self.raw_alloca_var(generator, ctx, None); let ndims = ndarray.load_ndims(ctx); // The caller has the responsibility to allocate 'indices' for `NDIter`. @@ -198,36 +221,8 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index dc241b6e..bdd4e79c 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,13 +1,12 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::codegen::{ - values::{ArraySliceValue, ProxyValue, RangeValue}, + values::{ProxyValue, RangeValue}, {CodeGenContext, CodeGenerator}, }; @@ -78,15 +77,29 @@ impl<'ctx> RangeType<'ctx> { } /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(self.raw_alloca(ctx, name), name) + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), name, ) } @@ -126,36 +139,8 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { Self::is_representable(llvm_ty) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index dd7643f7..fa5a3474 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + values::IntValue, AddressSpace, }; use itertools::Itertools; @@ -15,7 +15,7 @@ use crate::codegen::{ }, ProxyType, }, - values::{utils::SliceValue, ArraySliceValue, ProxyValue}, + values::{utils::SliceValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -154,16 +154,35 @@ impl<'ctx> SliceType<'ctx> { self.int_ty } - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.int_ty, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.int_ty, self.llvm_usize, name, @@ -210,36 +229,8 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { Self::is_representable(llvm_ty, generator.get_size_type(ctx)) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> PointerValue<'ctx> { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 87e2f1d8..f3b03dd1 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -118,7 +118,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca(generator, ctx, self.name); + .alloca_var(generator, ctx, self.name); // Set ndims and shape. let ndims = self diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 69c00807..3d575028 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -248,7 +248,7 @@ impl<'ctx> RustNDIndex<'ctx> { RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) - .alloca(generator, ctx, None); + .alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( From 1ffe2fcc7ff826f20e9b5c127ee1133cdca54aa2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jan 2025 15:28:28 +0800 Subject: [PATCH 13/80] [core] irrt: Minor reformat --- nac3core/irrt/irrt/math.hpp | 2 ++ nac3core/irrt/irrt/string.hpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nac3core/irrt/irrt/math.hpp b/nac3core/irrt/irrt/math.hpp index 1872f564..9dc1377e 100644 --- a/nac3core/irrt/irrt/math.hpp +++ b/nac3core/irrt/irrt/math.hpp @@ -1,5 +1,7 @@ #pragma once +#include "irrt/int_types.hpp" + namespace { // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index db3ad7fd..229b7509 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -5,7 +5,7 @@ namespace { template bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { - if (len1 != len2){ + if (len1 != len2) { return 0; } return __builtin_memcmp(str1, str2, static_cast(len1)) == 0; From 805a9d23b353d078c18deb8902f4b77854a7167f Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 3 Jan 2025 13:58:46 +0800 Subject: [PATCH 14/80] [core] codegen: Add derive(Copy, Clone) to TypedArrayLikeAdapter --- nac3core/src/codegen/values/array.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 55e91b21..b756f278 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -208,6 +208,7 @@ pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntV } /// An adapter for constraining untyped array values as typed values. +#[derive(Copy, Clone)] pub struct TypedArrayLikeAdapter< 'ctx, G: CodeGenerator + ?Sized, From 822f9d33f86693df45e532bde65165751cf9886e Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 13 Dec 2024 16:05:54 +0800 Subject: [PATCH 15/80] [core] codegen: Refactor ListType to use derive(StructFields) --- nac3core/src/codegen/expr.rs | 68 +++--- nac3core/src/codegen/numpy.rs | 14 +- nac3core/src/codegen/types/list.rs | 262 +++++++++++++++++++----- nac3core/src/codegen/types/structure.rs | 15 ++ nac3core/src/codegen/values/list.rs | 64 ++---- 5 files changed, 285 insertions(+), 138 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4b781d26..8d7f8e35 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1090,33 +1090,6 @@ pub fn destructure_range<'ctx>( (start, end, step) } -/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting -/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. -/// -/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element -/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to -/// generate a sized list with an unknown element type. -pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Option>, - length: IntValue<'ctx>, - name: Option<&'ctx str>, -) -> ListValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ty.unwrap_or(llvm_usize.into()); - - // List structure; type { ty*, size_t } - let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca_var(generator, ctx, name); - - let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); - list.store_size(ctx, generator, length); - list.create_data(ctx, llvm_elem_ty, None); - - list -} - /// Generates LLVM IR for a [list comprehension expression][expr]. pub fn gen_comprehension<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1189,12 +1162,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = allocate_list( + list = ListType::new(generator, ctx.ctx, elem_ty).construct( generator, ctx, - Some(elem_ty), list_alloc_size.into_int_value(), - Some("listcomp.addr"), + Some("listcomp"), ); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); @@ -1241,7 +1213,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + list = ListType::new(generator, ctx.ctx, elem_ty).construct( + generator, + ctx, + length, + Some("listcomp"), + ); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1406,7 +1383,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); + let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) + .construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1493,10 +1471,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = allocate_list( + let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( generator, ctx, - Some(elem_llvm_ty), ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), None, ); @@ -2553,7 +2530,20 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Some(elements[0].get_type()) }; let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); - let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); + let arr_str_ptr = if let Some(ty) = ty { + ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("list"), + ) + } else { + ListType::new_untyped(generator, ctx.ctx).construct_empty( + generator, + ctx, + Some("list"), + ) + }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { let elem_ptr = arr_ptr.ptr_offset( @@ -3031,8 +3021,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = - allocate_list(generator, ctx, Some(ty), length, Some("ret")); + let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("ret"), + ); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9328bb83..9b5af0f1 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, + types::{BasicType, BasicTypeEnum, PointerType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -639,17 +639,17 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); + let list_elem_ty = list_ty.element_type().unwrap(); let ndims = llvm_usize.const_int(1, false); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Getting ndims for list[ndarray] not supported") @@ -670,10 +670,10 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let list_elem_ty = src_lst.get_type().element_type(); + let list_elem_ty = src_lst.get_type().element_type().unwrap(); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { // The stride of elements in this dimension, i.e. the number of elements between arr[i] @@ -733,7 +733,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( )?; } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Not implemented for list[ndarray]") diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 6608a808..337d049c 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,68 +1,113 @@ use inkwell::{ - context::Context, + context::{AsContextRef, Context}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - AddressSpace, + values::{IntValue, PointerValue}, + AddressSpace, IntPredicate, OptimizationLevel, }; +use itertools::Itertools; + +use nac3core_derive::StructFields; use super::ProxyType; -use crate::codegen::{ - values::{ListValue, ProxyValue}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + types::structure::{ + check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + }, + values::{ListValue, ProxyValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, }; /// Proxy type for a `list` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct ListType<'ctx> { ty: PointerType<'ctx>, + item: Option>, llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ListStructFields<'ctx> { + /// Array pointer to content. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub items: StructField<'ctx, PointerValue<'ctx>>, + + /// Number of items in the array. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ListStructFields<'ctx> { + #[must_use] + pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + ListStructFields { + items: StructField::create( + &mut counter, + "items", + item.ptr_type(AddressSpace::default()), + ), + len: StructField::create(&mut counter, "len", llvm_usize), + } + } +} + impl<'ctx> ListType<'ctx> { /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. pub fn is_representable( llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Result<(), String> { - let llvm_list_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); - }; - if llvm_list_ty.count_fields() != 2 { - return Err(format!( - "Expected 2 fields in `list`, got {}", - llvm_list_ty.count_fields() - )); - } + let ctx = llvm_ty.get_context(); - let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); - let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); + let llvm_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); }; - let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); - let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); - }; - if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width() - )); - } + let fields = ListStructFields::new(ctx, llvm_usize); - Ok(()) + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { + ListStructFields::new_typed(item, llvm_usize) + } + + /// See [`ListType::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, _ctx: &impl AsContextRef<'ctx>) -> ListStructFields<'ctx> { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of a `List`. #[must_use] fn llvm_type( ctx: &'ctx Context, - element_type: BasicTypeEnum<'ctx>, + element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - // struct List { data: T*, size: size_t } - let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()]; + let element_type = element_type.unwrap_or(llvm_usize.into()); + + let field_tys = + Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } @@ -75,9 +120,50 @@ impl<'ctx> ListType<'ctx> { element_type: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - ListType::from_type(llvm_list, llvm_usize) + Self { ty: llvm_list, item: Some(element_type), llvm_usize } + } + + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + + Self { ty: llvm_list, item: None, llvm_usize } + } + + /// Creates an [`ListType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `item_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { + None + } else { + Some(ctx.get_llvm_type(generator, elem_type)) + }; + + Self { + ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), + item: llvm_elem_type, + llvm_usize, + } } /// Creates an [`ListType`] from a [`PointerType`]. @@ -85,30 +171,39 @@ impl<'ctx> ListType<'ctx> { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - ListType { ty: ptr_ty, llvm_usize } + let ctx = ptr_ty.get_context(); + + // We are just searching for the index off a field - Slot an arbitrary element type in. + let item_field_idx = + Self::fields(ctx.i8_type().into(), llvm_usize).index_of_field(|f| f.items); + let item = unsafe { + ptr_ty + .get_element_type() + .into_struct_type() + .get_field_type_at_index_unchecked(item_field_idx) + .into_pointer_type() + .get_element_type() + }; + let item = BasicTypeEnum::try_from(item).unwrap_or_else(|()| { + panic!( + "Expected BasicTypeEnum for list element type, got {}", + ptr_ty.get_element_type().print_to_string() + ) + }); + + ListType { ty: ptr_ty, item: Some(item), llvm_usize } } /// Returns the type of the `size` field of this `list` type. #[must_use] pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(1) - .map(BasicTypeEnum::into_int_type) - .unwrap() + self.llvm_usize } /// Returns the element type of this `list` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> Option> { + self.item } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. @@ -144,6 +239,73 @@ impl<'ctx> ListType<'ctx> { ) } + /// Allocates a [`ListValue`] on the stack using `item` of this [`ListType`] instance. + /// + /// The returned list will contain: + /// + /// - `data`: Allocated with `len` number of elements. + /// - `len`: Initialized to the value of `len` passed to this function. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let len = ctx.builder.build_int_z_extend(len, self.llvm_usize, "").unwrap(); + + // Generate a runtime assertion if allocating a non-empty list with unknown element type + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None && self.item.is_none() { + let len_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, len, self.llvm_usize.const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + len_eqz, + "0:AssertionError", + "Cannot allocate a non-empty list with unknown element type", + [None, None, None], + ctx.current_loc, + ); + } + + let plist = self.alloca_var(generator, ctx, name); + plist.store_size(ctx, generator, len); + + let item = self.item.unwrap_or(self.llvm_usize.into()); + plist.create_data(ctx, item, None); + + plist + } + + /// Convenience function for creating a list with zero elements. + /// + /// This function is preferred over [`ListType::construct`] if the length is known to always be + /// 0, as this function avoids injecting an IR assertion for checking if a non-empty untyped + /// list is being allocated. + /// + /// The returned list will contain: + /// + /// - `data`: Initialized to `(T*) 0`. + /// - `len`: Initialized to `0`. + #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + let plist = self.alloca_var(generator, ctx, name); + + plist.store_size(ctx, generator, self.llvm_usize.const_zero()); + plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); + + plist + } + /// Converts an existing value into a [`ListValue`]. #[must_use] pub fn map_value( diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 4d6dcaf7..87781d11 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -5,6 +5,7 @@ use inkwell::{ types::{BasicTypeEnum, IntType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, }; +use itertools::Itertools; use crate::codegen::CodeGenContext; @@ -55,6 +56,20 @@ pub trait StructFields<'ctx>: Eq + Copy { { self.into_vec().into_iter() } + + /// Returns the field index of a field in this structure. + fn index_of_field(&self, name: impl FnOnce(&Self) -> StructField<'ctx, V>) -> u32 + where + V: BasicValue<'ctx> + TryFrom, Error = ()>, + { + let field_name = name(self).name; + self.index_of_field_name(field_name).unwrap() + } + + /// Returns the field index of a field with the given name in this structure. + fn index_of_field_name(&self, field_name: &str) -> Option { + self.iter().find_position(|(name, _)| *name == field_name).map(|(idx, _)| idx as u32) + } } /// A single field of an LLVM structure. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 549bfe3f..bd115a2d 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::ListType, + types::{structure::StructField, ListType}, {CodeGenContext, CodeGenerator}, }; @@ -42,48 +42,26 @@ impl<'ctx> ListValue<'ctx> { ListValue { value: ptr, llvm_usize, name } } + fn items_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).items + } + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Returns the pointer to the field storing the size of this `list`. - fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + self.items_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); + self.items_field(ctx).set(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element /// type `elem_ty` and `size`. /// - /// If `size` is [None], the size stored in the field of this instance is used instead. + /// If `size` is [None], the size stored in the field of this instance is used instead. If + /// `size` is resolved to `0` at runtime, `(T*) 0` will be assigned to `data`. pub fn create_data( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -114,6 +92,10 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } + fn len_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).len + } + /// Stores the `size` of this `list` into this instance. pub fn store_size( &self, @@ -123,22 +105,16 @@ impl<'ctx> ListValue<'ctx> { ) { debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); - let psize = self.ptr_to_size(ctx); - ctx.builder.build_store(psize, size).unwrap(); + self.len_field(ctx).set(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. - pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let psize = self.ptr_to_size(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.size"))) - .unwrap_or_default(); - - ctx.builder - .build_load(psize, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() + pub fn load_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> IntValue<'ctx> { + self.len_field(ctx).get(ctx, self.value, name) } } From 7d02f5833d5eea92545607f39e7368463987a987 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 13:56:08 +0800 Subject: [PATCH 16/80] [core] codegen: Implement Tuple{Type,Value} --- nac3core/src/codegen/builtin_fns.rs | 80 ++++-------- nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/tuple.rs | 184 +++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/tuple.rs | 85 +++++++++++++ 6 files changed, 303 insertions(+), 54 deletions(-) create mode 100644 nac3core/src/codegen/types/tuple.rs create mode 100644 nac3core/src/codegen/values/tuple.rs diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index b21e721e..32b95a75 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -14,7 +14,7 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, TupleType}, values::{ ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, @@ -1868,34 +1868,6 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( }) } -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: &[BasicValueEnum<'ctx>], -) -> PointerValue<'ctx> { - let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, *v).unwrap(); - } - } - out_ptr -} - /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -1973,10 +1945,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( None, ); - let q = q.as_base_value().into(); - let r = r.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[q, r]); - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + let q = q.as_base_value().as_basic_value_enum(); + let r = r.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) + .construct_from_objects(ctx, [q, r], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_svd` linalg function @@ -2031,12 +2004,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( None, ); - let u = u.as_base_value().into(); - let s = s.as_base_value().into(); - let vh = vh.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[u, s, vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + let u = u.as_base_value().as_basic_value_enum(); + let s = s.as_base_value().as_basic_value_enum(); + let vh = vh.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + .construct_from_objects(ctx, [u, s, vh], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_inv` linalg function @@ -2158,10 +2131,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( None, ); - let l = l.as_base_value().into(); - let u = u.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[l, u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + let l = l.as_base_value().as_basic_value_enum(); + let u = u.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) + .construct_from_objects(ctx, [l, u], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -2293,10 +2267,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( None, ); - let t = t.as_base_value().into(); - let z = z.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[t, z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) + let t = t.as_base_value().as_basic_value_enum(); + let z = z.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) + .construct_from_objects(ctx, [t, z], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2337,8 +2312,9 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( None, ); - let h = h.as_base_value().into(); - let q = q.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[h, q]); - Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) + let h = h.as_base_value().as_basic_value_enum(); + let q = q.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) + .construct_from_objects(ctx, [h, q], None); + Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1e0fb268..2ce3c9ab 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -42,7 +42,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -574,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - ctx.struct_type(&fields, false).into() + TupleType::new(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 98bd43ba..0a31d6a5 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -28,11 +28,13 @@ use super::{ }; pub use list::*; pub use range::*; +pub use tuple::*; mod list; pub mod ndarray; mod range; pub mod structure; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs new file mode 100644 index 00000000..ccb63b4a --- /dev/null +++ b/nac3core/src/codegen/types/tuple.rs @@ -0,0 +1,184 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, StructType}, + values::BasicValueEnum, +}; +use itertools::Itertools; + +use super::ProxyType; +use crate::{ + codegen::{ + values::{ProxyValue, TupleValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct TupleType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> TupleType<'ctx> { + /// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not. + pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> { + Ok(()) + } + + /// Creates an LLVM type corresponding to the expected structure of a tuple. + #[must_use] + fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { + ctx.struct_type(tys, false) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + + /// Creates an [`TupleType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { + panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty)); + }; + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize } + } + + /// Creates an [`TupleType`] from a [`StructType`]. + #[must_use] + pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(struct_ty).is_ok()); + + TupleType { ty: struct_ty, llvm_usize } + } + + /// Returns the number of elements present in this [`TupleType`]. + #[must_use] + pub fn num_elements(&self) -> u32 { + self.ty.count_fields() + } + + /// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of + /// range. + #[must_use] + pub fn type_at_index(&self, index: u32) -> Option> { + if index < self.num_elements() { + Some(unsafe { self.type_at_index_unchecked(index) }) + } else { + None + } + } + + /// Returns the type of the tuple element at the given `index`. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid. + #[must_use] + pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> { + self.ty.get_field_type_at_index_unchecked(index) + } + + /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. + #[must_use] + pub fn construct( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name) + } + + /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of + /// objects. + #[must_use] + pub fn construct_from_objects>>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + objects: I, + name: Option<&'ctx str>, + ) -> >::Value { + let values = objects.into_iter().collect_vec(); + + assert_eq!(values.len(), self.num_elements() as usize); + assert!(values + .iter() + .enumerate() + .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); + + let mut value = self.construct(ctx, name); + for (i, val) in values.into_iter().enumerate() { + value.store_element(ctx, i as u32, val); + } + + value + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type Base = StructType<'ctx>; + type Value = TupleValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected struct type, got {llvm_ty:?}")) + } + } + + fn is_representable( + _generator: &G, + _ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: TupleType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 032f0417..c789fe0f 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -5,11 +5,13 @@ use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use range::*; +pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs new file mode 100644 index 00000000..5167e479 --- /dev/null +++ b/nac3core/src/codegen/values/tuple.rs @@ -0,0 +1,85 @@ +use inkwell::{ + types::IntType, + values::{BasicValue, BasicValueEnum, StructValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::TupleType, CodeGenContext}; + +#[derive(Copy, Clone)] +pub struct TupleValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> TupleValue<'ctx> { + /// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an + /// instance. + pub fn is_representable( + value: StructValue<'ctx>, + _llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + TupleType::is_representable(value.get_type()) + } + + /// Creates an [`TupleValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(value, llvm_usize).is_ok()); + + Self { value, llvm_usize, name } + } + + /// Stores a value into the tuple element at the given `index`. + pub fn store_element( + &mut self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + element: impl BasicValue<'ctx>, + ) { + assert_eq!(element.as_basic_value_enum().get_type(), unsafe { + self.get_type().type_at_index_unchecked(index) + }); + + let new_value = ctx + .builder + .build_insert_value(self.value, element, index, self.name.unwrap_or_default()) + .unwrap(); + self.value = new_value.into_struct_value(); + } + + /// Loads a value from the tuple element at the given `index`. + pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + ctx.builder + .build_extract_value( + self.value, + index, + &format!("{}[{{i}}]", self.name.unwrap_or("tuple")), + ) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type Base = StructValue<'ctx>; + type Type = TupleType<'ctx>; + + fn get_type(&self) -> Self::Type { + TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: TupleValue<'ctx>) -> Self { + value.as_base_value() + } +} From 5880f964bbf2b134690dff6a0d04c533cc45f2c9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 16 Dec 2024 15:26:18 +0800 Subject: [PATCH 17/80] [core] codegen/ndarray: Reimplement np_{zeros,ones,full,empty} Based on 792374fa: core/ndstrides: implement np_{zeros,ones,full,empty}. --- nac3core/src/codegen/numpy.rs | 291 +++--------------- nac3core/src/codegen/types/ndarray/factory.rs | 146 +++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- nac3core/src/codegen/values/ndarray/mod.rs | 18 ++ nac3core/src/codegen/values/ndarray/shape.rs | 152 +++++++++ 6 files changed, 374 insertions(+), 241 deletions(-) create mode 100644 nac3core/src/codegen/types/ndarray/factory.rs create mode 100644 nac3core/src/codegen/values/ndarray/shape.rs diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9b5af0f1..9c57919c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -3,7 +3,6 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::Itertools; use nac3parser::ast::{Operator, StrRef}; @@ -19,17 +18,28 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ndarray::NDArrayType, ListType, ProxyType}, + types::{ + ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, + }, + ListType, ProxyType, + }, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, - ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{extract_ndims, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::{ magic_methods::Binop, typedef::{FunSignature, Type, TypeEnum}, @@ -174,132 +184,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. - let ndims = shape_tuple.get_type().count_fields(); - - let shape = (0..ndims) - .map(|dim_i| { - ctx.builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .map(BasicValueEnum::into_int_value) - .map(|v| { - ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap() - }) - .unwrap() - }) - .collect_vec(); - - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - let shape_int = - ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap(); - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -529,107 +413,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - })?; - - Ok(ndarray) -} - /// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( generator: &G, @@ -1752,8 +1535,15 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_empty(generator, context, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1770,8 +1560,15 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_zeros(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.ones`. @@ -1788,8 +1585,15 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_ones(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.full`. @@ -1809,8 +1613,15 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + Ok(ndarray.as_base_value()) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs new file mode 100644 index 00000000..13aae8cd --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -0,0 +1,146 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use super::NDArrayType; +use crate::{ + codegen::{ + irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// Get the zero value in `np.zeros()` of a `dtype`. +pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayType<'ctx> { + /// Create an ndarray like + /// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html). + pub fn construct_numpy_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, name); + + // Validate `shape` + irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape); + + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + ndarray + } + + /// Create an ndarray like + /// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html). + pub fn construct_numpy_full( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + fill_value: BasicValueEnum<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_numpy_empty(generator, ctx, shape, name); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like + /// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html). + pub fn construct_numpy_zeros( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_zero_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } + + /// Create an ndarray like + /// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html). + pub fn construct_numpy_ones( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_one_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 3886ce84..89241618 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -25,6 +25,7 @@ pub use indexing::*; pub use nditer::*; mod contiguous; +pub mod factory; mod indexing; mod nditer; diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 7ce8ed79..9b71693a 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,8 +163,13 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { + assert!( + ndarray.get_type().ndims().is_some(), + "NDIter requires ndims of NDArray to be known." + ); + let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = ndarray.load_ndims(ctx); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index e47876c6..ffde76c9 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -23,6 +23,7 @@ pub use nditer::*; mod contiguous; mod indexing; mod nditer; +pub mod shape; mod view; /// Proxy type for accessing an `NDArray` value in LLVM. @@ -397,6 +398,23 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); } + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |_, ctx, _, nditer| { + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs new file mode 100644 index 00000000..190a1e4f --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -0,0 +1,152 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use crate::{ + codegen::{ + stmt::gen_for_callback_incrementing, + types::{ListType, TupleType}, + values::{ + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to +/// `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), +) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let zero = llvm_usize.const_zero(); + let one = llvm_usize.const_int(1, false); + + // The result `list` to return. + match &*ctx.unifier.get_ty_immutable(input_seq_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_pointer_value(), None); + + let len = input_seq.load_size(ctx, None); + // TODO: Find a way to remove this mid-BB allocation + let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap(); + let result = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(result, len, None), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_callback_incrementing( + generator, + ctx, + None, + zero, + (len, false), + |generator, ctx, _, i| { + // Load the i-th int32 in the input sequence + let int = unsafe { + input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value() + }; + + // Cast to SizeT + let int = + ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + // Store + unsafe { result.set_typed_unchecked(ctx, generator, &i, int) }; + + Ok(()) + }, + one, + ) + .unwrap(); + + result + } + + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_struct_value(), None); + + let len = input_seq.get_type().num_elements(); + + let result = generator + .gen_array_var_alloc( + ctx, + llvm_usize.into(), + llvm_usize.const_int(u64::from(len), false), + None, + ) + .unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + for i in 0..input_seq.get_type().num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_seq.load_element(ctx, i).into_int_value(); + let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + unsafe { + result.set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(u64::from(i), false), + int, + ); + } + } + + result + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + let input_int = input_seq.into_int_value(); + + let len = one; + let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let int = + ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap(); + + // Storing into result[0] + unsafe { + result.set_typed_unchecked(ctx, generator, &zero, int); + } + + result + } + + _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)), + } +} From 26f1428739568968aa335e37bab33909ee810931 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 14:21:13 +0800 Subject: [PATCH 18/80] [core] codegen: Refactor len() Based on 54a842a9: core/ndstrides: implement len(ndarray) & refactor len() --- nac3core/src/codegen/builtin_fns.rs | 61 ++++++++++++----------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 32b95a75..54650ab3 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -14,9 +14,9 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::{ndarray::NDArrayType, TupleType}, + types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, + ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -55,42 +55,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TTuple { .. } => { + let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_struct_value(), None); + llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) .map_value(arg.into_pointer_value(), None); - - let ndims = arg.shape().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + ctx.builder + .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") + .unwrap() } - _ => codegen_unreachable!(ctx), + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_pointer_value(), None); + ctx.builder + .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") + .unwrap() + } + + _ => unsupported_type(ctx, "len", &[arg_ty]), } }) } From fadadd7505c99b3accbd942200aa657d64ad28b1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 20 Aug 2024 14:51:40 +0800 Subject: [PATCH 19/80] [core] codegen/ndarray: Reimplement np_array() Based on 8f0084ac: core/ndstrides: implement np_array() It also checks for inconsistent dimensions if the input is a list. e.g., rejecting `[[1.0, 2.0], [3.0]]`. However, currently only `np_array(, copy=False)` and `np_array (, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(, copy=False)` behaves like NumPy's `np.array(, copy=None)`. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/list.hpp | 15 + nac3core/irrt/irrt/ndarray/array.hpp | 132 ++++++ nac3core/src/codegen/irrt/ndarray/array.rs | 80 ++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/numpy.rs | 460 +------------------- nac3core/src/codegen/types/ndarray/array.rs | 245 +++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/values/list.rs | 19 +- nac3core/src/codegen/values/ndarray/mod.rs | 2 +- 10 files changed, 512 insertions(+), 445 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/array.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/array.rs create mode 100644 nac3core/src/codegen/types/ndarray/array.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 0d069869..57f60d52 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,3 +9,4 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/array.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index 28543945..1edfe498 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -2,6 +2,21 @@ #include "irrt/int_types.hpp" #include "irrt/math_util.hpp" +#include "irrt/slice.hpp" + +namespace { +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. + */ +template +struct List { + uint8_t* items; + SizeT len; +}; +} // namespace extern "C" { // Handle list assignment and dropping part of the list when diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 00000000..126669e7 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/list.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::array { +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], + * [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the + * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because + * of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List* list, SizeT ndims, SizeT* shape) { + if (shape[axis] == -1) { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } else { + // Dimension is specified. Check. + if (shape[axis] != list->len) { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. + raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) { + // `list` has type `list[ItemType]` + // Do nothing + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + for (SizeT i = 0; i < list->len; i++) { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template +void set_and_validate_list_shape(List* list, SizeT ndims, SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. + */ +template +void write_list_to_array_helper(SizeT axis, SizeT* index, List* list, NDArray* ndarray) { + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) { + if (!ndarray::basic::is_c_contiguous(ndarray)) { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t* dst = static_cast(ndarray->data) + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + + for (SizeT i = 0; i < list->len; i++) { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template +void write_list_to_array(List* list, NDArray* ndarray) { + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace ndarray::array +} // namespace + +extern "C" { +using namespace ndarray::array; + +void __nac3_ndarray_array_set_and_validate_list_shape(List* list, int32_t ndims, int32_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_set_and_validate_list_shape64(List* list, int64_t ndims, int64_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_write_list_to_array(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} + +void __nac3_ndarray_array_write_list_to_array64(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs new file mode 100644 index 00000000..931b66cb --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -0,0 +1,80 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_array_set_and_validate_list_shape`. +/// +/// Deduces the target shape of the `ndarray` from the provided `list`, raising an exception if +/// there is any issue with the resultant `shape`. +/// +/// `shape` must be pre-allocated by the caller of this function to `[usize; ndims]`, and must be +/// initialized to all `-1`s. +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndims: IntValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + assert_eq!(ndims.get_type(), llvm_usize); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_set_and_validate_list_shape", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_array_write_list_to_array`. +/// +/// Copies the contents stored in `list` into `ndarray`. +/// +/// The `ndarray` must fulfill the following preconditions: +/// +/// - `ndarray.itemsize`: Must be initialized. +/// - `ndarray.ndims`: Must be initialized. +/// - `ndarray.shape`: Must be initialized. +/// - `ndarray.data`: Must be allocated and contiguous. +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) { + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_write_list_to_array", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndarray.as_base_value().into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56017c94..307ec6bb 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -16,10 +16,12 @@ use crate::codegen::{ }, CodeGenContext, CodeGenerator, }; +pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +mod array; mod basic; mod indexing; mod iter; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9c57919c..09d848dc 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,7 +1,7 @@ use inkwell::{ - types::{BasicType, BasicTypeEnum, PointerType}, + types::BasicType, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + IntPredicate, OptimizationLevel, }; use nac3parser::ast::{Operator, StrRef}; @@ -18,12 +18,9 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ - ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, - ListType, ProxyType, + types::ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, }, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, @@ -35,14 +32,10 @@ use super::{ }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{ - helper::{extract_ndims, PrimDef}, - numpy::unpack_ndarray_var_tys, - DefinitionId, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{FunSignature, Type}, }, }; @@ -413,394 +406,6 @@ where Ok(res) } -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type().unwrap(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type().unwrap(); - - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - let offset = ctx - .builder - .build_int_mul( - offset, - ctx.builder - .build_int_truncate_or_bit_cast( - dst_arr.get_type().element_type().size_of().unwrap(), - offset.get_type(), - "", - ) - .unwrap(), - "", - ) - .unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_pointer_value( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - let sizeof_elem = - ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. -fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_pointer_value( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_elem_ty, - None, - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_representable(object, llvm_usize).is_ok()); - let object = ListValue::from_pointer_value(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = - ListValue::from_pointer_value(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(next_dim, llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_pointer_value( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - /// LLVM-typed implementation for generating the implementation for `ndarray.eye`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1635,26 +1240,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1670,28 +1255,17 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let copy = generator.bool_to_i1(context, copy_arg.into_int_value()); + let ndarray = NDArrayType::from_unifier_type(generator, context, fun.0.ret) + .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.eye`. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs new file mode 100644 index 00000000..87cd002a --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -0,0 +1,245 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, IntValue}, + AddressSpace, +}; + +use crate::{ + codegen::{ + irrt, + stmt::gen_if_else_expr_callback, + types::{ndarray::NDArrayType, ListType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array()`. +fn get_list_object_dtype_and_ndims<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + list_ty: Type, +) -> (BasicTypeEnum<'ctx>, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list_ty); + let ndims = arraylike_get_ndims(&mut ctx.unifier, list_ty); + + (ctx.get_llvm_type(generator, dtype), ndims) +} + +impl<'ctx> NDArrayType<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn construct_numpy_array_from_list_copy_true_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert_eq!(dtype, self.dtype); + + let list_value = list.as_i8_list(generator, ctx); + + // Validate `list` has a consistent shape. + // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = self.llvm_usize.const_int(ndims_int, false); + let shape = ctx.builder.build_array_alloca(self.llvm_usize, ndims, "").unwrap(); + let shape = ArraySliceValue::from_ptr_val(shape, ndims, None); + let shape = TypedArrayLikeAdapter::from( + shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + irrt::ndarray::call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, &shape, + ); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, name); + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + // Copy all contents from the list. + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( + generator, ctx, list_value, ndarray, + ); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn construct_numpy_array_from_list_copy_none_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + if ndims == 1 { + // `list` is not nested + assert_eq!(ndims, 1); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert_eq!(dtype, self.dtype); + + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + .construct_uninitialized(generator, ctx, name); + + // Set data + let data = ctx + .builder + .build_pointer_cast(list.data().base_ptr(ctx, generator), llvm_pi8, "") + .unwrap(); + ndarray.store_data(ctx, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.shape(); + let list_len = list.load_size(ctx, None); + unsafe { + shape.set_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), list_len); + } + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(generator, ctx); + + ndarray + } else { + // `list` is nested, copy + self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn construct_numpy_array_list_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_none_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn construct_numpy_array_ndarray_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(ndarray.get_type().dtype, self.dtype); + assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self + .ndims + .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.as_base_value())) + }, + |_generator, _ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + ndarray.get_type().map_value(ndarray_val, name) + } + + /// Create a new ndarray like + /// [`np.array()`](https://numpy.org/doc/stable/reference/generated/numpy.array.html). + /// + /// Note that the returned [`NDArrayValue`] may have fewer dimensions than is specified by this + /// instance. Use [`NDArrayValue::atleast_nd`] on the returned value if an `ndarray` instance + /// with the exact number of dimensions is needed. + pub fn construct_numpy_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + match &*ctx.unifier.get_ty_immutable(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) + } + + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 89241618..17bb6adc 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -24,6 +24,7 @@ pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod array; mod contiguous; pub mod factory; mod indexing; diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index bd115a2d..c497f8f8 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::{structure::StructField, ListType}, + types::{structure::StructField, ListType, ProxyType}, {CodeGenContext, CodeGenerator}, }; @@ -116,6 +116,23 @@ impl<'ctx> ListValue<'ctx> { ) -> IntValue<'ctx> { self.len_field(ctx).get(ctx, self.value, name) } + + /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. + #[must_use] + pub fn as_i8_list( + &self, + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> ListValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + + Self::from_pointer_value( + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + self.llvm_usize, + self.name, + ) + } } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index ffde76c9..d4a460a5 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -173,7 +173,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { let data = ctx .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") From acb437919d951e0c14a468a55d2908282cc16044 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 18:01:12 +0800 Subject: [PATCH 20/80] [core] codegen/ndarray: Reimplement np_{eye,identity} Based on fa047d50: core/ndstrides: implement np_identity() and np_eye() --- nac3core/src/codegen/numpy.rs | 107 ++++++------------ nac3core/src/codegen/types/ndarray/factory.rs | 94 ++++++++++++++- 2 files changed, 126 insertions(+), 75 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 09d848dc..703d03ec 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,10 +18,7 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, + types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, @@ -406,55 +403,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let nrows = context + .builder + .build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ncols = context + .builder + .build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let offset = context + .builder + .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") + .unwrap(); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.identity`. @@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let n = context + .builder + .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_identity(generator, context, dtype, n, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.copy`. diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 13aae8cd..300167f7 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -1,4 +1,7 @@ -use inkwell::values::{BasicValueEnum, IntValue}; +use inkwell::{ + values::{BasicValueEnum, IntValue}, + IntPredicate, +}; use super::NDArrayType; use crate::{ @@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( } /// Get the one value in `np.ones()` of a `dtype`. -pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, @@ -143,4 +146,91 @@ impl<'ctx> NDArrayType<'ctx> { let fill_value = ndarray_one_value(generator, ctx, dtype); self.construct_numpy_full(generator, ctx, shape, fill_value, name) } + + /// Create an ndarray like + /// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html). + #[allow(clippy::too_many_arguments)] + pub fn construct_numpy_eye( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + assert_eq!(nrows.get_type(), self.llvm_usize); + assert_eq!(ncols.get_type(), self.llvm_usize); + assert_eq!(offset.get_type(), self.llvm_usize); + + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name); + + // Create data and make the matrix like look np.eye() + unsafe { + ndarray.create_data(generator, ctx); + } + ndarray + .foreach(generator, ctx, |generator, ctx, _, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + let indices = nditer.get_indices(); + + let row_i = unsafe { + indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None) + }; + let col_i = unsafe { + indices.get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(1, false), + None, + ) + }; + + let be_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + ctx.builder.build_int_add(row_i, offset, "").unwrap(), + col_i, + "", + ) + .unwrap(); + let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like + /// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html). + pub fn construct_numpy_identity( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let offset = self.llvm_usize.const_zero(); + self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name) + } } From 9ffa2d6552e5052aba16bdba9b9f081efa138c68 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 09:53:00 +0800 Subject: [PATCH 21/80] [core] codegen/ndarray: Reimplement np_{copy,fill} Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy() --- nac3core/src/codegen/numpy.rs | 57 ++++------------------ nac3core/src/codegen/values/ndarray/mod.rs | 2 + 2 files changed, 11 insertions(+), 48 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 703d03ec..2f899f95 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>( assert!(args.is_empty()); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_copy_impl( - generator, - context, - this_elem_ty, - llvm_this_ty.map_value(this_arg.into_pointer_value(), None), - ) - .map(NDArrayValue::into) + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.fill`. @@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>( assert_eq!(args.len(), 1); let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_fill_flattened( - generator, - context, - llvm_this_ty.map_value(this_arg, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + this.fill(generator, context, value_arg); Ok(()) } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d4a460a5..2907445e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>, ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. self.foreach(generator, ctx, |_, ctx, _, nditer| { let p = nditer.get_pointer(ctx); ctx.builder.build_store(p, value).unwrap(); From 12358c57b1a17592e18fa6f2acd161b87b22269b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:28:56 +0800 Subject: [PATCH 22/80] [core] codegen/ndarray: Implement np_{shape,strides} Based on 40c24486: core/ndstrides: implement np_shape() and np_strides() These functions are not important, but they are handy for debugging. `np.strides()` is not an actual NumPy function, but `ndarray.strides` is used. --- nac3core/src/codegen/values/ndarray/mod.rs | 86 +++++++++++++++++-- nac3core/src/toplevel/builtins.rs | 53 ++++++++++++ nac3core/src/toplevel/helper.rs | 8 ++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 41 ++++++++- nac3standalone/demo/interpret_demo.py | 4 + 10 files changed, 193 insertions(+), 13 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 2907445e..951792fb 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -1,19 +1,23 @@ +use std::iter::repeat_n; + use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }; use crate::codegen::{ irrt, llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, stmt::gen_for_callback_incrementing, type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField}, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; pub use contiguous::*; @@ -417,13 +421,85 @@ impl<'ctx> NDArrayValue<'ctx> { .unwrap(); } + /// Create the shape tuple of this ndarray like + /// [`np.shape()`](https://numpy.org/doc/stable/reference/generated/numpy.shape.html). + /// + /// All elements in the tuple are `i32`. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.shape().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + + /// Create the strides tuple of this ndarray like + /// [`.strides`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html). + /// + /// All elements in the tuple are `i32`. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.strides().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { self.ndims.map(|ndims| ndims == 0) } - /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. /// Otherwise, do nothing and return the ndarray itself. // TODO: Rename to get_unsized_element pub fn split_unsized( diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 36bd85d1..ac3fa08f 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,6 +14,7 @@ use crate::{ builtin_fns, numpy::*, stmt::exn_constructor, + types::ndarray::NDArrayType, values::{ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, @@ -368,6 +369,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1242,6 +1247,54 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.as_base_value().into())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 71c1859b..75a7eabc 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,10 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpShape, + FunNpStrides, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -240,6 +244,10 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 93f2096b..9313448e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index d3301d00..0aa21de1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 911426b9..2490cc75 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index d60daf83..a7230f4d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 517f6846..871a2f89 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 6068f630..8f1c54fc 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::{ cmp::max, collections::{HashMap, HashSet}, convert::{From, TryInto}, - iter::once, + iter::{once, repeat_n}, sync::Arc, }; @@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..5bcf4bb5 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray property getters + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf From 132ba1942f018981d201e34520bea070d6137e5b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:58:38 +0800 Subject: [PATCH 23/80] [core] toplevel: Implement np_size Based on 2c1030d1: core/ndstrides: implement np_size() --- nac3core/src/toplevel/builtins.rs | 35 +++++++++++++++++-- nac3core/src/toplevel/helper.rs | 2 ++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +-- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +-- nac3standalone/demo/interpret_demo.py | 1 + 8 files changed, 43 insertions(+), 9 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ac3fa08f..de5a278e 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -369,7 +369,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpShape | PrimDef::FunNpStrides => { + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { self.build_ndarray_property_getter_function(prim) } @@ -1248,7 +1248,10 @@ impl<'a> BuiltinBuilder<'a> { } fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1257,6 +1260,34 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let size = ctx + .builder + .build_int_truncate_or_bit_cast( + ndarray.size(generator, ctx), + ctx.ctx.i32_type(), + "", + ) + .unwrap(); + Ok(Some(size.into())) + }), + ), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { // The function signatures of `np_shape` an `np_size` are the same. // Mixed together for convenience. diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 75a7eabc..10e77422 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -55,6 +55,7 @@ pub enum PrimDef { FunNpIdentity, // NumPy ndarray property getters + FunNpSize, FunNpShape, FunNpStrides, @@ -245,6 +246,7 @@ impl PrimDef { PrimDef::FunNpIdentity => fun("np_identity", None), // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 9313448e..4650fbbf 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 0aa21de1..b67596d8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 2490cc75..08f254f5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index a7230f4d..ce3b02ed 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 871a2f89..b053b814 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 5bcf4bb5..ca17c3da 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray property getters + module.np_size = np.size module.np_shape = np.shape module.np_strides = lambda ndarray: ndarray.strides From aae41eef6a5de558b4202f076a4cb1785aaaf32b Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 11:05:45 +0800 Subject: [PATCH 24/80] [core] toplevel: Add view functions category Based on 9e0f636d: core: categorize np_{transpose,reshape} as 'view functions' --- nac3core/src/toplevel/builtins.rs | 108 +++++++++--------- nac3core/src/toplevel/helper.rs | 12 +- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3standalone/demo/interpret_demo.py | 6 +- 8 files changed, 72 insertions(+), 68 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index de5a278e..276c00c7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -373,6 +373,10 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } + PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_ndarray_view_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -438,10 +442,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { - self.build_np_sp_ndarray_function(prim) - } - PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -1326,6 +1326,55 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build np/sp functions that take as input `NDArray` only + fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpTranspose => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + }), + ), + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpReshape => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }), + ), + + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1813,57 +1862,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build np/sp functions that take as input `NDArray` only - fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); - - match prim { - PrimDef::FunNpTranspose => { - let ndarray_ty = self.unifier.get_fresh_var_with_range( - &[self.ndarray_num_ty], - Some("T".into()), - None, - ); - create_fn_by_codegen( - self.unifier, - &into_var_map([ndarray_ty]), - prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ) - } - - // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and - // the `param_ty` for `create_fn_by_codegen`. - // - // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking - // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], - // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), - - _ => unreachable!(), - } - } - /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 10e77422..9313b13b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -59,6 +59,10 @@ pub enum PrimDef { FunNpShape, FunNpStrides, + // NumPy ndarray view functions + FunNpTranspose, + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -106,8 +110,6 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunNpTranspose, - FunNpReshape, // Linalg functions FunNpDot, @@ -250,6 +252,10 @@ impl PrimDef { PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), + // NumPy NDArray view functions + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), @@ -297,8 +303,6 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunNpTranspose => fun("np_transpose", None), - PrimDef::FunNpReshape => fun("np_reshape", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 4650fbbf..b03b3616 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b67596d8..b4df49c9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 08f254f5..65a6a8ac 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index ce3b02ed..cfedf1f6 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index b053b814..9a77a21c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", ] diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index ca17c3da..56c6126d 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray view functions + module.np_transpose = np.transpose + module.np_reshape = np.reshape + # NumPy NDArray property getters module.np_size = np.size module.np_shape = np.shape @@ -223,8 +227,6 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter - module.np_transpose = np.transpose - module.np_reshape = np.reshape # SciPy Math functions module.sp_spec_erf = special.erf From 8d975b5ff3ecf04c71ba7971aac8c1f7c91dab81 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 11:40:23 +0800 Subject: [PATCH 25/80] [core] codegen/ndarray: Implement np_reshape Based on 926e7e93: core/ndstrides: implement np_reshape() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/reshape.hpp | 97 +++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/irrt/ndarray/reshape.rs | 40 +++ nac3core/src/codegen/numpy.rs | 334 +----------------- nac3core/src/codegen/values/ndarray/view.rs | 77 +++- nac3core/src/toplevel/builtins.rs | 59 +++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3standalone/demo/src/ndarray.py | 55 ++- 13 files changed, 308 insertions(+), 373 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/reshape.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/reshape.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 57f60d52..bb7fb3d4 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,5 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" -#include "irrt/ndarray/array.hpp" \ No newline at end of file +#include "irrt/ndarray/array.hpp" +#include "irrt/ndarray/reshape.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 00000000..b2ad2a5b --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::reshape { +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template +void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) { + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) { + SizeT dim = new_shape[axis_i]; + if (dim < 0) { + if (dim == -1) { + if (neg1_exists) { + // Multiple `-1` found. Throw an error. + raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, + NO_PARAM, NO_PARAM); + } else { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } else { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... + + raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, + NO_PARAM); + } + } else { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) { + // Let `x` be the unknown dimension + // Solve `x * = ` + if (new_size == 0 && size == 0) { + // `x` has infinitely many solutions + can_reshape = false; + } else if (new_size == 0 && size != 0) { + // `x` has no solutions + can_reshape = false; + } else if (size % new_size != 0) { + // `x` has no integer solutions + can_reshape = false; + } else { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } else { + can_reshape = (new_size == size); + } + + if (!can_reshape) { + raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, + NO_PARAM); + } +} +} // namespace ndarray::reshape +} // namespace + +extern "C" { +void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} + +void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 307ec6bb..f67566b9 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -20,11 +20,13 @@ pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +pub use reshape::*; mod array; mod basic; mod indexing; mod iter; +mod reshape; /// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs new file mode 100644 index 00000000..32de2fa1 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -0,0 +1,40 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ArrayLikeValue, ArraySliceValue}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`. +/// +/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(, new_shape)`, raising an +/// assertion if multiple dimensions are unknown (`-1`). +pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + new_ndims: IntValue<'ctx>, + new_shape: ArraySliceValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(size.get_type(), llvm_usize); + assert_eq!(new_ndims.get_type(), llvm_usize); + assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_reshape_resolve_and_check_new_shape", + ); + infer_and_call_function( + ctx, + &name, + None, + &[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 2f899f95..ce4e2059 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,9 +21,9 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -134,46 +134,6 @@ where Ok(ndarray) } -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - } - - let llvm_dtype = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64)) - .construct_dyn_shape(generator, ctx, shape, None); - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -1455,294 +1415,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( } } -/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. -/// -/// * `x1` - `NDArray` to reshape. -/// * `shape` - The `shape` parameter used to construct the new `NDArray`. -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` -/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// -/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. -pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - shape: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_reshape"; - let (x1_ty, x1) = x1; - let (_, shape) = shape; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = n1.size(generator, ctx); - - let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); - ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); - - let out = match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - // Check for -1 in dimensions - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_list.load_size(ctx, None), false), - |generator, ctx, _, idx| { - let ele = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - ele, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_neg_value = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_neg_value = ctx - .builder - .build_int_add( - num_neg_value, - llvm_usize.const_int(1, false), - "", - ) - .unwrap(); - ctx.builder.build_store(num_neg, num_neg_value).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_value = - ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_value = - ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); - ctx.builder.build_store(acc, acc_value).unwrap(); - Ok(None) - }, - )?; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - // Generate the output shape by filling -1 with `rem` - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, _| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - let dim = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` - - let ndims = shape_tuple.get_type().count_fields(); - // Check for -1 in dims - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_negs = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_negs = ctx - .builder - .build_int_add(num_negs, llvm_usize.const_int(1, false), "") - .unwrap(); - ctx.builder.build_store(num_neg, num_negs).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); - ctx.builder.build_store(acc, acc_val).unwrap(); - Ok(None) - }, - )?; - } - - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - let mut shape = Vec::with_capacity(ndims as usize); - - // Reconstruct shape filling negatives with rem - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - let dim = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value(); - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` - let shape_int = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - shape_int, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(n_sz)), - |_, ctx| { - Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) - }, - )? - .unwrap() - .into_int_value(); - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } - .unwrap(); - - // Only allow one dimension to be negative - let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "can only specify one unknown dimension", - [None, None, None], - ctx.current_loc, - ); - - // The new shape must be compatible with the old shape - let out_sz = out.size(generator, ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), - "0:ValueError", - "cannot reshape array of size {0} into provided shape of size {1}", - [Some(n_sz), Some(out_sz), None], - ctx.current_loc, - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 70a9d659..8334d379 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,9 +1,16 @@ use std::iter::{once, repeat_n}; +use inkwell::values::IntValue; use itertools::Itertools; use crate::codegen::{ - values::ndarray::{NDArrayValue, RustNDIndex}, + irrt, + stmt::gen_if_callback, + types::ndarray::NDArrayType, + values::{ + ndarray::{NDArrayValue, RustNDIndex}, + ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, + }, CodeGenContext, CodeGenerator, }; @@ -33,4 +40,72 @@ impl<'ctx> NDArrayValue<'ctx> { *self } } + + /// Create a reshaped view on this ndarray like + /// [`np.reshape()`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html). + /// + /// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be + /// modified as a result. + /// + /// If reshape without copying is impossible, this function will allocate a new ndarray and copy + /// contents. + /// + /// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`]. + /// * `new_shape` - The target shape to do `np.reshape()`. + #[must_use] + pub fn reshape_or_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + new_ndims: u64, + new_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert_eq!(new_shape.element_type(ctx, generator), self.llvm_usize.into()); + + // TODO: The current criterion for whether to do a full copy or not is by checking + // `is_c_contiguous`, but this is not optimal - there are cases when the ndarray is + // not contiguous but could be reshaped without copying data. Look into how numpy does + // it. + + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + .construct_uninitialized(generator, ctx, None); + dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); + + // Resolve negative indices + let size = self.size(generator, ctx); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_shape = dst_ndarray.shape(); + irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( + generator, + ctx, + size, + dst_ndims, + dst_shape.as_slice_value(ctx, generator), + ); + + gen_if_callback( + generator, + ctx, + |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |generator, ctx| { + // Reshape is possible without copying + dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); + + Ok(()) + }, + |generator, ctx| { + // Reshape is impossible without copying + unsafe { + dst_ndarray.create_data(generator, ctx); + } + dst_ndarray.copy_data_from(generator, ctx, *self); + + Ok(()) + }, + ) + .unwrap(); + + dst_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 276c00c7..db7acaf3 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -5,8 +5,10 @@ use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ - helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails}, - numpy::make_ndarray_ty, + helper::{ + debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + }, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, }; use crate::{ @@ -15,7 +17,7 @@ use crate::{ numpy::*, stmt::exn_constructor, types::ndarray::NDArrayType, - values::{ProxyValue, RangeValue}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -193,7 +195,6 @@ struct BuiltinBuilder<'a> { ndarray_float: Type, ndarray_float_2d: Type, - ndarray_num_ty: Type, float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, @@ -307,7 +308,6 @@ impl<'a> BuiltinBuilder<'a> { ndarray_float, ndarray_float_2d, - ndarray_num_ty, float_or_ndarray_ty, float_or_ndarray_var_map, @@ -1356,20 +1356,41 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), + PrimDef::FunNpReshape => { + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[ + (in_ndarray_ty.ty, "x"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding + ], + Box::new(move |ctx, _, fun, args, generator| { + let ndarray_ty = fun.0.args[0].ty; + let ndarray_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let shape_ty = fun.0.args[1].ty; + let shape_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray_val.into_pointer_value(), None); + + let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); + + // The ndims after reshaping is gotten from the return type of the call. + let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + }), + ) + } _ => unreachable!(), } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index b03b3616..f5349890 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b4df49c9..05408683 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 65a6a8ac..029815aa 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index cfedf1f6..fc8f1aaf 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a77a21c..34e78a2b 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", ] diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index d42f3b93..a82cfbad 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,13 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_4(n: ndarray[float, Literal[4]]): + for x in range(len(n)): + for y in range(len(n[x])): + for z in range(len(n[x][y])): + for w in range(len(n[x][y][z])): + output_float64(n[x][y][z][w]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -197,6 +204,38 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_reshape(): + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x = np_reshape(w, (1, 2, 1, -1)) + y = np_reshape(x, [2, -1]) + z = np_reshape(y, 10) + + output_int32(np_shape(w)[0]) + output_ndarray_float_1(w) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_int32(np_shape(x)[2]) + output_int32(np_shape(x)[3]) + output_ndarray_float_4(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_ndarray_float_1(z) + + x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) + x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + + output_int32(np_shape(x1)[0]) + output_ndarray_int32_1(x1) + + output_int32(np_shape(x2)[0]) + output_int32(np_shape(x2)[1]) + output_ndarray_int32_2(x2) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1448,19 +1487,6 @@ def test_ndarray_transpose(): output_ndarray_float_2(x) output_ndarray_float_2(y) -def test_ndarray_reshape(): - w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - x = np_reshape(w, (1, 2, 1, -1)) - y = np_reshape(x, [2, -1]) - z = np_reshape(y, 10) - - x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) - x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) - - output_ndarray_float_1(w) - output_ndarray_float_2(y) - output_ndarray_float_1(z) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1592,6 +1618,8 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_reshape() + test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar() @@ -1756,7 +1784,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_transpose() - test_ndarray_reshape() test_ndarray_dot() test_ndarray_cholesky() From 43e440d2fda7857f19544d61e2dccee21b5a79eb Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 15:23:41 +0800 Subject: [PATCH 26/80] [core] codegen/ndarray: Reimplement broadcasting Based on 9359ed96: core/ndstrides: implement broadcasting & np_broadcast_to() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/broadcast.hpp | 165 ++++++++++++ .../src/codegen/irrt/ndarray/broadcast.rs | 82 ++++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/types/ndarray/broadcast.rs | 176 +++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 16 ++ .../src/codegen/values/ndarray/broadcast.rs | 248 ++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 2 + nac3core/src/toplevel/builtins.rs | 24 +- nac3core/src/toplevel/helper.rs | 2 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 2 +- nac3standalone/demo/interpret_demo.py | 1 + nac3standalone/demo/src/ndarray.py | 24 ++ 18 files changed, 748 insertions(+), 13 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/broadcast.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/types/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/values/ndarray/broadcast.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index bb7fb3d4..09844d3b 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -10,4 +10,5 @@ #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" -#include "irrt/ndarray/reshape.hpp" \ No newline at end of file +#include "irrt/ndarray/reshape.hpp" +#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..6e54b1ca --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +namespace { +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray::broadcast { +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) { + if (src_ndims > target_ndims) { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry* shapes, SizeT dst_ndims, SizeT* dst_shape) { + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) { + dst_shape[dst_axis] = entry_dim; + } else if (entry_dim == 1 || entry_dim == dst_dim) { + // Do nothing + } else { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + +#ifdef IRRT_DEBUG_ASSERT + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +#endif +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template +void broadcast_to(const NDArray* src_ndarray, NDArray* dst_ndarray) { + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } else { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace ndarray::broadcast +} // namespace + +extern "C" { +using namespace ndarray::broadcast; + +void __nac3_ndarray_broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_to64(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, + const ShapeEntry* shapes, + int32_t dst_ndims, + int32_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} + +void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, + const ShapeEntry* shapes, + int64_t dst_ndims, + int64_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs new file mode 100644 index 00000000..cb1ecd4c --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -0,0 +1,82 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + types::{ndarray::ShapeEntryType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_broadcast_to`. +/// +/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`. +/// +/// `dst_ndarray` must meet the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. +/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + infer_and_call_function( + ctx, + &name, + None, + &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_broadcast_shapes`. +/// +/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`, +/// writing the result to `dst_shape`. +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + num_shape_entries: IntValue<'ctx>, + shape_entries: ArraySliceValue<'ctx>, + dst_ndims: IntValue<'ctx>, + dst_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(num_shape_entries.get_type(), llvm_usize); + assert!(ShapeEntryType::is_type( + generator, + ctx.ctx, + shape_entries.base_ptr(ctx, generator).get_type() + ) + .is_ok()); + assert_eq!(dst_ndims.get_type(), llvm_usize); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + infer_and_call_function( + ctx, + &name, + None, + &[ + num_shape_entries.into(), + shape_entries.base_ptr(ctx, generator).into(), + dst_ndims.into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index f67566b9..c640042b 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -18,12 +18,14 @@ use crate::codegen::{ }; pub use array::*; pub use basic::*; +pub use broadcast::*; pub use indexing::*; pub use iter::*; pub use reshape::*; mod array; mod basic; +mod broadcast; mod indexing; mod iter; mod reshape; diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs new file mode 100644 index 00000000..5ee28454 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -0,0 +1,176 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::codegen::{ + types::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, + }, + values::{ndarray::ShapeEntryValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ShapeEntryType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ShapeEntryStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> ShapeEntryType<'ctx> { + /// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not. + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_ndarray_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ShapeEntryStructFields<'ctx> { + ShapeEntryStructFields::new(ctx, llvm_usize) + } + + /// See [`ShapeEntryStructFields::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> { + Self::fields(ctx, self.llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_ty = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_ty, llvm_usize } + } + + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type Base = PointerType<'ctx>; + type Value = ShapeEntryValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ShapeEntryType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 17bb6adc..316d0f33 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -20,11 +20,13 @@ use crate::{ toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; mod array; +mod broadcast; mod contiguous; pub mod factory; mod indexing; @@ -118,6 +120,20 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + } + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] pub fn new_unsized( diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs new file mode 100644 index 00000000..0c84b05e --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -0,0 +1,248 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue}, +}; +use itertools::Itertools; + +use crate::codegen::{ + irrt, + types::{ + ndarray::{NDArrayType, ShapeEntryType}, + structure::StructField, + ProxyType, + }, + values::{ + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct ShapeEntryValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ShapeEntryValue<'ctx> { + /// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is + /// not an instance. + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + >::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).ndims + } + + /// Stores the number of dimensions into this value. + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.ndims_field().set(ctx, self.value, value, self.name); + } + + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).shape + } + + /// Stores the shape into this value. + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.shape_field().set(ctx, self.value, value, self.name); + } +} + +impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = ShapeEntryType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ShapeEntryValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); + + let broadcast_ndarray = + NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) + .construct_uninitialized(generator, ctx, None); + broadcast_ndarray.copy_shape_from_array( + generator, + ctx, + target_shape.base_ptr(ctx, generator), + ); + + irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); + broadcast_ndarray + } +} + +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Clone)] +pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + + /// The broadcasting shape. + pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, + + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call [`irrt::ndarray::call_nac3_ndarray_broadcast_shapes`]. +fn broadcast_shapes<'ctx, G, Shape>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + + assert!(in_shape_entries + .iter() + .all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into())); + assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into()); + + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false); + let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = unsafe { + shape_entries.ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + None, + ) + }; + let shape_entry = llvm_shape_ty.map_value(pshape_entry, None); + + let in_ndims = llvm_usize.const_int(*in_ndims, false); + shape_entry.store_ndims(ctx, in_ndims); + + shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator)); + } + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false); + irrt::ndarray::call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayType<'ctx> { + /// Broadcast all ndarrays according to + /// [`np.broadcast()`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html) + /// and return a [`BroadcastAllResult`] containing all the information of the result of the + /// broadcast operation. + pub fn broadcast( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[NDArrayValue<'ctx>], + ) -> BroadcastAllResult<'ctx, G> { + assert!(!ndarrays.is_empty()); + assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); + + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); + assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); + let broadcast_shape = ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(), + broadcast_ndims, + None, + ); + let broadcast_shape = TypedArrayLikeAdapter::from( + broadcast_shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| { + ( + ndarray.shape().as_slice_value(ctx, generator), + ndarray.get_type().ndims().unwrap(), + ) + }) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 951792fb..d91f734d 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -20,10 +20,12 @@ use crate::codegen::{ types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod broadcast; mod contiguous; mod indexing; mod nditer; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index db7acaf3..3f71b983 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -373,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1328,7 +1328,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1356,7 +1359,10 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1386,7 +1392,17 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, &shape) + } + + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, &shape) + } + + _ => unreachable!(), + }; Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9313b13b..de90a41b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -60,6 +60,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -253,6 +254,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index f5349890..41b39bb8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 05408683..90408d91 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 029815aa..f0418889 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index fc8f1aaf..72e54e02 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 34e78a2b..a8a534cd 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8f1c54fc..87692114 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1594,7 +1594,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 56c6126d..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index a82cfbad..374bcf73 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_3(n: ndarray[float, Literal[3]]): + for d in range(len(n)): + for r in range(len(n[d])): + for c in range(len(n[d][r])): + output_float64(n[d][r][c]) + def output_ndarray_float_4(n: ndarray[float, Literal[4]]): for x in range(len(n)): for y in range(len(n[x])): @@ -236,6 +242,23 @@ def test_ndarray_reshape(): output_int32(np_shape(x2)[1]) output_ndarray_int32_2(x2) +def test_ndarray_broadcast_to(): + xs = np_array([1.0, 2.0, 3.0]) + ys = np_broadcast_to(xs, (1, 3)) + zs = np_broadcast_to(ys, (2, 4, 3)) + + output_int32(np_shape(xs)[0]) + output_ndarray_float_1(xs) + + output_int32(np_shape(ys)[0]) + output_int32(np_shape(ys)[1]) + output_ndarray_float_2(ys) + + output_int32(np_shape(zs)[0]) + output_int32(np_shape(zs)[1]) + output_int32(np_shape(zs)[2]) + output_ndarray_float_3(zs) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1619,6 +1642,7 @@ def run() -> int32: test_ndarray_nd_idx() test_ndarray_reshape() + test_ndarray_broadcast_to() test_ndarray_add() test_ndarray_add_broadcast() From 7375983e0c60d04651424aa3cde509b52a20a2d3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:22:59 +0800 Subject: [PATCH 27/80] [core] codegen/ndarray: Implement np_transpose without axes argument Based on 052b67c8: core/ndstrides: implement np_transpose() (no axes argument) The IRRT implementation knows how to handle axes. But the argument is not in NAC3 yet. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/transpose.hpp | 143 ++++++++++++++++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/irrt/ndarray/transpose.rs | 48 ++++++ nac3core/src/codegen/numpy.rs | 108 ------------- nac3core/src/codegen/values/ndarray/view.rs | 50 +++++- nac3core/src/toplevel/builtins.rs | 7 +- nac3standalone/demo/src/ndarray.py | 27 ++-- 8 files changed, 267 insertions(+), 121 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/transpose.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/transpose.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 09844d3b..39ddba67 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -11,4 +11,5 @@ #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" -#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/transpose.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 00000000..662ceb1e --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace { +namespace ndarray::transpose { +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template +void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) { + if (ndims != num_axes) { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == -1) { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray* src_ndarray, NDArray* dst_ndarray, SizeT num_axes, const SizeT* axes) { + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } else { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace ndarray::transpose +} // namespace + +extern "C" { +using namespace ndarray::transpose; +void __nac3_ndarray_transpose(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int32_t num_axes, + const int32_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} + +void __nac3_ndarray_transpose64(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int64_t num_axes, + const int64_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index c640042b..ba22568e 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -22,6 +22,7 @@ pub use broadcast::*; pub use indexing::*; pub use iter::*; pub use reshape::*; +pub use transpose::*; mod array; mod basic; @@ -29,6 +30,7 @@ mod broadcast; mod indexing; mod iter; mod reshape; +mod transpose; /// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs new file mode 100644 index 00000000..57661fa7 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -0,0 +1,48 @@ +use inkwell::{values::IntValue, AddressSpace}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_transpose`. +/// +/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`. +/// +/// `dst_ndarray` must fulfill the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`. +/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, + axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); + assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + infer_and_call_function( + ctx, + &name, + None, + &[ + src_ndarray.as_base_value().into(), + dst_ndarray.as_base_value().into(), + axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), + axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { + axes.base_ptr(ctx, generator) + }) + .into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ce4e2059..fdbb716b 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1307,114 +1307,6 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. -pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (x1_ty, x1): (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = n1.size(generator, ctx); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 8334d379..450f7444 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,6 +1,6 @@ use std::iter::{once, repeat_n}; -use inkwell::values::IntValue; +use inkwell::values::{IntValue, PointerValue}; use itertools::Itertools; use crate::codegen::{ @@ -9,7 +9,7 @@ use crate::codegen::{ types::ndarray::NDArrayType, values::{ ndarray::{NDArrayValue, RustNDIndex}, - ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -108,4 +108,50 @@ impl<'ctx> NDArrayValue<'ctx> { dst_ndarray } + + /// Create a transposed view on this ndarray like + /// [`np.transpose(, = None)`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html). + /// + /// * `axes` - If specified, should be an array of the permutation (negative indices are + /// **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>, + ) -> Self { + assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); + assert!( + axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) + ); + + // Define models + let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); + + let axes = if let Some(axes) = axes { + let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false); + + // `axes = nullptr` if `axes` is unspecified. + let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None); + + Some(TypedArrayLikeAdapter::from( + axes, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + )) + } else { + None + }; + + irrt::ndarray::call_nac3_ndarray_transpose( + generator, + ctx, + *self, + transposed_ndarray, + axes.as_ref(), + ); + + transposed_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 3f71b983..538961a6 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1349,7 +1349,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg_val.into_pointer_value(), None); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.as_base_value().into())) }), ), diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 374bcf73..170ac14c 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -210,6 +210,23 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_transpose(): + x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) + y = np_transpose(x) + z = np_transpose(y) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_ndarray_float_2(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_int32(np_shape(z)[1]) + output_ndarray_float_2(z) + def test_ndarray_reshape(): w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) x = np_reshape(w, (1, 2, 1, -1)) @@ -1502,14 +1519,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) -def test_ndarray_transpose(): - x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) - y = np_transpose(x) - z = np_transpose(y) - - output_ndarray_float_2(x) - output_ndarray_float_2(y) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1641,6 +1650,7 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_transpose() test_ndarray_reshape() test_ndarray_broadcast_to() @@ -1807,7 +1817,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() - test_ndarray_transpose() test_ndarray_dot() test_ndarray_cholesky() From dcde1d9c87bc7629236c0f059eb250869871f896 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:32:34 +0800 Subject: [PATCH 28/80] [core] codegen/values/ndarray: Add more ScalarOrNDArray utils Based on f731e604: core/ndstrides: add more ScalarOrNDArray and NDArrayObject utils --- nac3core/src/codegen/values/ndarray/mod.rs | 121 ++++++++++++++++++--- 1 file changed, 107 insertions(+), 14 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d91f734d..6c8c9aab 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -12,13 +12,16 @@ use super::{ TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; -use crate::codegen::{ - irrt, - llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, - stmt::gen_for_callback_incrementing, - type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + irrt, + llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, + stmt::gen_for_callback_incrementing, + type_aligned_alloca, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, }; pub use broadcast::*; pub use contiguous::*; @@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> { self.ndims.map(|ndims| ndims == 0) } - /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. - /// Otherwise, do nothing and return the ndarray itself. - // TODO: Rename to get_unsized_element - pub fn split_unsized( + /// Returns the element present in this `ndarray` if this is unsized. + pub fn get_unsized_element( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ) -> ScalarOrNDArray<'ctx> { - let Some(is_unsized) = self.is_unsized() else { todo!() }; + ) -> Option> { + let Some(is_unsized) = self.is_unsized() else { + panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + }; if is_unsized { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; - ScalarOrNDArray::Scalar(value) + Some(value) + } else { + None + } + } + + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { + ScalarOrNDArray::Scalar(unsized_elem) } else { ScalarOrNDArray::NDArray(*self) } @@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayValue<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn from_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + ScalarOrNDArray::NDArray(ndarray) + } + + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { @@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), } } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with + /// [`NDArrayType::construct_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayValue<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => { + NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None) + } + } + } + + /// Get the dtype of the ndarray created if this were called with + /// [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.get_type(), + } + } } From 2dc5e79a23ea630f88b1bef3faa43f52e2e807ef Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 16:44:57 +0800 Subject: [PATCH 29/80] [core] codegen/ndarray: Implement subscript assignment Based on 5bed394e: core/ndstrides: implement subscript assignment Overlapping is not handled. Currently it has undefined behavior. --- nac3core/src/codegen/stmt.rs | 55 ++++++++++++++++++++++++++++-- nac3standalone/demo/src/ndarray.py | 33 ++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 3595528f..edebb4f0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -16,7 +16,11 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, + types::ndarray::NDArrayType, + values::{ + ndarray::{RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -411,7 +415,54 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + + // Process key + let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + + // Reference code: + // ```python + // target = target[key] + // value = np.asarray(value) + // + // shape = np.broadcast_shape((target, value)) + // + // target = np.broadcast_to(target, shape) + // value = np.broadcast_to(value, shape) + // + // # ...and finally copy 1-1 from value to target. + // ``` + + let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) + .map_value(target.into_pointer_value(), None); + let target = target.index(generator, ctx, &key); + + let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) + .to_ndarray(generator, ctx); + + let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] + .iter() + .filter_map(|ndims| *ndims) + .max(); + let broadcast_result = NDArrayType::new( + generator, + ctx.ctx, + value.get_type().element_type(), + broadcast_ndims, + ) + .broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 170ac14c..b668860f 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -276,6 +276,38 @@ def test_ndarray_broadcast_to(): output_int32(np_shape(zs)[2]) output_ndarray_float_3(zs) +def test_ndarray_subscript_assignment(): + xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]]) + + xs[0, 0] = 99.0 + output_ndarray_float_2(xs) + + xs[0] = 100.0 + output_ndarray_float_2(xs) + + xs[:, ::2] = 101.0 + output_ndarray_float_2(xs) + + xs[1:, 0] = 102.0 + output_ndarray_float_2(xs) + + xs[0] = np_array([-1.0, -2.0, -3.0, -4.0]) + output_ndarray_float_2(xs) + + xs[:] = np_array([-5.0, -6.0, -7.0, -8.0]) + output_ndarray_float_2(xs) + + # Test assignment with memory sharing + ys1 = np_reshape(xs, (2, 4)) + ys2 = np_transpose(ys1) + ys3 = ys2[::-1, 0] + ys3[0] = -999.0 + + output_ndarray_float_2(xs) + output_ndarray_float_2(ys1) + output_ndarray_float_2(ys2) + output_ndarray_float_1(ys3) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1653,6 +1685,7 @@ def run() -> int32: test_ndarray_transpose() test_ndarray_reshape() test_ndarray_broadcast_to() + test_ndarray_subscript_assignment() test_ndarray_add() test_ndarray_add_broadcast() From e6dab25a570643911878d244be5e07b7c162dea3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 17:15:51 +0800 Subject: [PATCH 30/80] [core] codegen/ndarray: Add NDArrayOut, broadcast_map, map Based on fbfc0b29: core/ndstrides: add NDArrayOut, broadcast_map and map --- nac3core/src/codegen/types/ndarray/map.rs | 187 +++++++++++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/values/ndarray/map.rs | 69 ++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 45 +++++ 4 files changed, 302 insertions(+) create mode 100644 nac3core/src/codegen/types/ndarray/map.rs create mode 100644 nac3core/src/codegen/values/ndarray/map.rs diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs new file mode 100644 index 00000000..0d63b22f --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -0,0 +1,187 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; +use itertools::Itertools; + +use crate::codegen::{ + stmt::gen_for_callback, + types::{ + ndarray::{NDArrayType, NDIterType}, + ProxyType, + }, + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ArrayLikeValue, ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayType<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` + /// elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when + /// iterating through the input `ndarrays` after broadcasting. The output of `mapping` is the + /// result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing + /// ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[NDArrayValue<'ctx>], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result<>::Value, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = self.broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = + NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + .construct_uninitialized(generator, ctx, None); + result_ndarray.copy_shape_from_array( + generator, + ctx, + broadcast_result.shape.base_ptr(ctx, generator), + ); + unsafe { + result_ndarray.create_data(generator, ctx); + } + result_ndarray + } + + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| { + NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) + }) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_element()`. + // `in_nditers`' `has_element()`s should return the same value. + Ok(out_nditer.has_element(generator, ctx)) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + let in_scalars = + in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a + /// scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars + /// or ndarrays or a mix of them as their inputs and produces either an ndarray with broadcast, + /// or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following + /// behaviors: + /// + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is + /// converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> + /// ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a + /// [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be + /// 'as-ndarray'-ed into ndarrays, then all inputs (now all ndarrays) will be passed to + /// [`NDArrayValue::broadcasting_starmap`] and **create** a new ndarray with dtype `ret_dtype`. + pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: BasicTypeEnum<'ctx>, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = + inputs.iter().map(BasicValueEnum::<'ctx>::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().copied().collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(value)) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ret_dtype, + &inputs.iter().map(NDArrayValue::get_type).collect_vec(), + ) + .broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 316d0f33..43712be1 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -30,6 +30,7 @@ mod broadcast; mod contiguous; pub mod factory; mod indexing; +mod map; mod nditer; /// Proxy type for a `ndarray` type in LLVM. diff --git a/nac3core/src/codegen/values/ndarray/map.rs b/nac3core/src/codegen/values/ndarray/map.rs new file mode 100644 index 00000000..72d1bf9d --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/map.rs @@ -0,0 +1,69 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; + +use crate::codegen::{ + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Map through this ndarray with an elementwise function. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + self.get_type().broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a + /// [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new + /// ndarray of the results will be created and returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: BasicTypeEnum<'ctx>, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 6c8c9aab..89f88e74 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -31,6 +31,7 @@ pub use nditer::*; mod broadcast; mod contiguous; mod indexing; +mod map; mod nditer; pub mod shape; mod view; @@ -540,6 +541,26 @@ impl<'ctx> NDArrayValue<'ctx> { ScalarOrNDArray::NDArray(*self) } } + + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. + pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_shape: impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) { + let ndarray_shape = self.shape(); + let output_shape = out_shape; + + irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + &ndarray_shape, + &output_shape, + ); + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { @@ -1081,3 +1102,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } } + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a +/// function will create a new ndarray and store the result in it. +#[derive(Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: BasicTypeEnum<'ctx> }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayValue<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +} From 6cbba8fdde439afad1cfab65cb3267ee483e043a Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 11:24:28 +0800 Subject: [PATCH 31/80] [core] codegen: Reimplement builtin funcs to support strided ndarrays Based on 7f3c4530: core/ndstrides: update builtin_fns to use ndarray with strides --- nac3core/src/codegen/builtin_fns.rs | 959 +++++++++++----------------- 1 file changed, 368 insertions(+), 591 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 54650ab3..7c8ad7a8 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -11,19 +11,16 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - numpy, - numpy::ndarray_elementwise_unaryop_impl, - stmt::gen_for_callback_incrementing, types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; use crate::{ toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, }, typecheck::typedef::{Type, TypeEnum}, @@ -129,18 +126,18 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_int32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -189,18 +186,18 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_int64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -265,18 +262,18 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_uint32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -330,18 +327,18 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_uint64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -355,7 +352,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { @@ -394,20 +390,19 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndims = extract_ndims(&ctx.unifier, ndims); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_float(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -440,18 +435,20 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty.into() }, + |generator, ctx, scalar| { + call_round(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -477,18 +474,18 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_numpy_round(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -539,22 +536,21 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (elem_ty, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -591,18 +587,20 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -639,18 +637,20 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -741,42 +741,37 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -861,23 +856,26 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => codegen_unreachable!(ctx), } } + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); - let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = n.size(generator, ctx); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None); + let llvm_dtype = ndarray.get_type().element_type(); + + let zero = llvm_usize.const_zero(); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx + let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") .unwrap(); ctx.make_assert( generator, - n_sz_eqz, + size_nez, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -885,54 +883,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = - generator.gen_var_alloc(ctx, llvm_ndarray_ty.element_type(), None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, llvm_dtype, None)?; + let extremum_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = unsafe { ndarray.data().get_unchecked(ctx, generator, &zero, None) }; + ctx.builder.build_store(extremum, first_value).unwrap(); + ctx.builder.build_store(extremum_idx, zero).unwrap(); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { - n.data().get_unchecked( - ctx, - generator, - &ctx.builder - .build_int_truncate_or_bit_cast(idx, llvm_usize, "") - .unwrap(), - None, - ) - }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |_, ctx, _, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = ctx + .builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); - let result = match fn_name { + let curr_value = nditer.get_scalar(ctx); + let curr_idx = nditer.get_index(ctx); + + let new_extremum = match fn_name { "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_min(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_max(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } _ => codegen_unreachable!(ctx), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), @@ -942,24 +929,35 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + ctx.builder.build_store(extremum_idx, new_extremum_idx).unwrap(); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => ctx + .builder + .build_int_s_extend_or_bit_cast( + ctx.builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(), + ctx.ctx.i64_type(), + "", + ) + .unwrap() + .into(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => codegen_unreachable!(ctx), } } @@ -1006,42 +1004,37 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1074,39 +1067,20 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = ScalarOrNDArray::from_value(generator, ctx, (arg_ty, arg_val)); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(x, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; + let ret_ty = get_ret_elem_type(ctx, dtype); + let llvm_ret_ty = ctx.get_llvm_type(generator, ret_ty); + let result = arg.map(generator, ctx, llvm_ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; - Ok(result) + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1431,59 +1405,29 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. @@ -1495,59 +1439,29 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. @@ -1559,59 +1473,29 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. @@ -1623,59 +1507,29 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. @@ -1687,48 +1541,31 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert_eq!(x1.get_dtype(), ctx.ctx.f64_type().into()); + debug_assert_eq!(x2.get_dtype(), ctx.ctx.i32_type().into()); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. @@ -1740,59 +1577,29 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1804,59 +1611,29 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_nextafter"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function From 59f19e29df1f9765ff6b83cffd3a9ca5d450b145 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:25:35 +0800 Subject: [PATCH 32/80] [core] codegen: Reimplement ndarray binop Based on 9e40c834: core/ndstrides: implement binop --- nac3core/src/codegen/expr.rs | 163 +++++++++++++++++------------------ 1 file changed, 80 insertions(+), 83 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8d7f8e35..8225df7c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,14 +34,19 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{ + helper::{arraylike_flatten_element_type, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, TopLevelDef, + }, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -1526,98 +1531,90 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); + let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if is_ndarray1 && is_ndarray2 { + if op.base == Operator::MatMult { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); - let left_val = NDArrayType::from_unifier_type(generator, ctx, ty1) - .map_value(left_val.into_pointer_value(), None); - let right_val = NDArrayType::from_unifier_type(generator, ctx, ty2) - .map_value(right_val.into_pointer_value(), None); - - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (ty1, left_val.as_base_value().into(), false), - (ty2, right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; - - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = - NDArrayType::from_unifier_type(generator, ctx, if is_ndarray1 { ty1 } else { ty2 }) - .map_value( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( + // MatMult is the only binop which is not an elementwise op + let result = numpy::ndarray_matmul_2d( generator, ctx, - ndarray_dtype, + ndarray_dtype1, match op.variant { BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (ty1, left_val, !is_ndarray1), - (ty2, right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype), lhs), - op, - (&Some(ndarray_dtype), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + BinopVariant::AugAssign => Some(left), }, + left, + right, )?; - Ok(Some(res.as_base_value().into())) + Ok(Some(result.as_base_value().into())) + } else { + // For other operations, they are all elementwise operations. + + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. + + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // If this is an augmented assignment. + // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap(generator, ctx, &[left, right], out, |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(ty1_dtype), left_value), + op, + (&Some(ty2_dtype), right_value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, common_dtype)?; + + Ok(result) + }) + .unwrap(); + Ok(Some(result.as_base_value().into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); From a2f1b25fd881ca34d34ad3304f67fc4b95e54aca Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:37:17 +0800 Subject: [PATCH 33/80] [core] codegen: Reimplement ndarray unary op Based on bb992704: core/ndstrides: implement unary op --- nac3core/src/codegen/expr.rs | 18 +++++------ nac3core/src/codegen/numpy.rs | 59 ----------------------------------- 2 files changed, 8 insertions(+), 69 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8225df7c..232c68c2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1777,10 +1777,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) + .map_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1798,20 +1798,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? + .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; - res.as_base_value().into() + mapped_ndarray.as_base_value().into() } else { unimplemented!() })) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index fdbb716b..d02103ab 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -195,28 +195,6 @@ where }) } -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of /// the target `ndarray`. fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( @@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - /// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output From ebbadc2d742ed1c7e8901e94fbf07fc1721c5401 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:46:24 +0800 Subject: [PATCH 34/80] [core] codegen: Reimplement ndarray cmpop Based on 56cccce1: core/ndstrides: implement cmpop --- nac3core/irrt/irrt/ndarray.hpp | 70 ------- nac3core/src/codegen/expr.rs | 109 ++++------ nac3core/src/codegen/irrt/ndarray/mod.rs | 170 +--------------- nac3core/src/codegen/numpy.rs | 248 +---------------------- 4 files changed, 43 insertions(+), 554 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 7fc9a63b..534f18d6 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n stride *= dims[i]; } } - -template -void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, - SizeT src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} } // namespace extern "C" { @@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32 void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } - -void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} } \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 232c68c2..8a002bb3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1852,83 +1852,52 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty); + let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left)) + .to_ndarray(generator, ctx); + let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right)) + .to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ctx.ctx.i8_type().into(), + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty) - .map_value(lhs.into_pointer_value(), None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, left_val.as_base_value().into(), false), - (right_ty, rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty_dtype), left_scalar), + &[op], + &[(Some(right_ty_dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, lhs, !is_ndarray1), - (right_ty, rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.as_base_value().into())); } } diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index ba22568e..151795c5 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,18 +1,15 @@ use inkwell::{ types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, IntPredicate, + AddressSpace, }; use itertools::Either; use super::get_usize_dependent_function_name; use crate::codegen::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, values::{ ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( |_, _, v| v.into(), ) } - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); - let ndarray_calc_broadcast_fn = - ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.shape().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.shape().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into()) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.shape().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d02103ab..9fe5a972 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -10,10 +10,7 @@ use super::{ expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, - ndarray::{ - call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, - call_ndarray_calc_nd_indices, call_ndarray_calc_size, - }, + ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, }, llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, @@ -21,7 +18,7 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, @@ -195,152 +192,6 @@ where }) } -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. -fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - (lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - (rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Returns the element of an ndarray indexed by the given indices, performing int-promotion on - // `indices` where necessary. - // - // Required for compatibility with `NDArrayType::get_unchecked`. - let get_data_by_indices_compat = - |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { - let llvm_usize = generator.get_size_type(ctx.ctx); - - // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT - let stackptr = llvm_intrinsics::call_stacksave(ctx, None); - let indices = if llvm_usize == ctx.ctx.i32_type() { - indices - } else { - let indices_usize = TypedArrayLikeAdapter::>::from( - ArraySliceValue::from_ptr_val( - ctx.builder - .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") - .unwrap(), - indices.size(ctx, generator), - None, - ), - |_, _, val| val.into_int_value(), - |_, _, val| val.into(), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (indices.size(ctx, generator), false), - |generator, ctx, _, i| { - let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; - let idx = ctx - .builder - .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") - .unwrap(); - unsafe { - indices_usize.set_typed_unchecked(ctx, generator, &i, idx); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - indices_usize - }; - - let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; - - llvm_intrinsics::call_stackrestore(ctx, stackptr); - - elem - }; - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (Type, BasicValueEnum<'ctx>, bool), - rhs: (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let (lhs_ty, lhs_val, lhs_scalar) = lhs; - let (rhs_ty, rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayType::from_unifier_type( - generator, - ctx, - if lhs_scalar { rhs_ty } else { lhs_ty }, - ) - .map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// /// * `elem_ty` - The element type of the `NDArray`. From 66b8a5e01d73d84f1f1312f6915e3977ec7cc24c Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:52 +0800 Subject: [PATCH 35/80] [core] codegen/ndarray: Reimplement matmul Based on 73c2203b: core/ndstrides: implement general matmul --- nac3core/irrt/irrt.cpp | 4 +- nac3core/irrt/irrt/int_types.hpp | 2 - nac3core/irrt/irrt/ndarray.hpp | 50 -- nac3core/irrt/irrt/ndarray/matmul.hpp | 98 +++ nac3core/src/codegen/expr.rs | 67 +- nac3core/src/codegen/irrt/ndarray/matmul.rs | 66 ++ nac3core/src/codegen/irrt/ndarray/mod.rs | 131 +--- nac3core/src/codegen/numpy.rs | 727 +----------------- nac3core/src/codegen/types/ndarray/factory.rs | 2 +- nac3core/src/codegen/values/ndarray/matmul.rs | 334 ++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 1 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 84 +- 17 files changed, 585 insertions(+), 995 deletions(-) delete mode 100644 nac3core/irrt/irrt/ndarray.hpp create mode 100644 nac3core/irrt/irrt/ndarray/matmul.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/matmul.rs create mode 100644 nac3core/src/codegen/values/ndarray/matmul.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 39ddba67..87dcb428 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,7 +1,6 @@ #include "irrt/exception.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" -#include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" #include "irrt/string.hpp" @@ -12,4 +11,5 @@ #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/broadcast.hpp" -#include "irrt/ndarray/transpose.hpp" \ No newline at end of file +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index ed8a48b8..17ccf604 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64); #endif -// NDArray indices are always `uint32_t`. -using NDIndexInt = uint32_t; // The type of an index or a value describing the length of a range/slice is always `int32_t`. using SliceIndex = int32_t; diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp deleted file mode 100644 index 534f18d6..00000000 --- a/nac3core/irrt/irrt/ndarray.hpp +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "irrt/int_types.hpp" - -// TODO: To be deleted since NDArray with strides is done. - -namespace { -template -SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} -} // namespace - -extern "C" { -uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t -__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} -} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..b0fd4d86 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace { +namespace ndarray::matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace ndarray::matmul +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8a002bb3..4b83e63e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -27,7 +27,7 @@ use super::{ call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -1534,26 +1534,35 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if op.base == Operator::MatMult { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // Augmented assignment - `left` has to be an ndarray. If it were a scalar then NAC3 + // simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + if op.base == Operator::MatMult { let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); - - // MatMult is the only binop which is not an elementwise op - let result = numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left), - }, - left, - right, - )?; - - Ok(Some(result.as_base_value().into())) + let result = left + .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) + .split_unsized(generator, ctx); + Ok(Some(result.to_basic_value_enum().into())) } else { // For other operations, they are all elementwise operations. @@ -1565,28 +1574,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // For all cases, the scalar operand is promoted to an ndarray, // the two are then broadcasted, and starmapped through. - let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); - let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); - - // Inhomogeneous binary operations are not supported. - assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); - - let common_dtype = ty1_dtype; - let llvm_common_dtype = left.get_dtype(); - - let out = match op.variant { - BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - BinopVariant::AugAssign => { - // If this is an augmented assignment. - // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. - if let ScalarOrNDArray::NDArray(out_ndarray) = left { - NDArrayOut::WriteToNDArray { ndarray: out_ndarray } - } else { - panic!("left must be an ndarray") - } - } - }; - let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs new file mode 100644 index 00000000..551cb7c7 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -0,0 +1,66 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, + values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`. +/// +/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of +/// `a @ b`. +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + final_ndims: IntValue<'ctx>, + new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!( + BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + + infer_and_call_function( + ctx, + &name, + None, + &[ + a_shape.size(ctx, generator).into(), + a_shape.base_ptr(ctx, generator).into(), + b_shape.size(ctx, generator).into(), + b_shape.base_ptr(ctx, generator).into(), + final_ndims.into(), + new_a_shape.base_ptr(ctx, generator).into(), + new_b_shape.base_ptr(ctx, generator).into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 151795c5..b1530685 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,23 +1,9 @@ -use inkwell::{ - types::BasicTypeEnum, - values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, -}; -use itertools::Either; - -use super::get_usize_dependent_function_name; -use crate::codegen::{ - values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAdapter, - }, - CodeGenContext, CodeGenerator, -}; pub use array::*; pub use basic::*; pub use broadcast::*; pub use indexing::*; pub use iter::*; +pub use matmul::*; pub use reshape::*; pub use transpose::*; @@ -26,119 +12,6 @@ mod basic; mod broadcast; mod indexing; mod iter; +mod matmul; mod reshape; mod transpose; - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns a -/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); - assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); - assert_eq!( - BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - - let ndarray_calc_size_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The `llvm_usize` index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert_eq!(index.get_type(), llvm_usize); - - let ndarray_calc_nd_indices_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9fe5a972..d46a6119 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,736 +1,23 @@ use inkwell::{ - types::BasicType, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - IntPredicate, OptimizationLevel, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; use super::{ - expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, - ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, - }, - llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{factory::ndarray_zero_value, NDArrayType}, - values::{ - ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, - TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, - UntypedArrayLikeMutator, - }, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; -/// Creates an `NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, num_dims, None); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = ndarray.size(generator, ctx); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. -fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -/// Copies a slice of an [`NDArrayValue`] to another. -/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); - - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let src_data_offset = ctx - .builder - .build_int_mul( - src_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - let dst_data_offset = ctx - .builder - .build_int_mul( - dst_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (dst_arr, dst_ptr), - (src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = - if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? - } else { - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.shape().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { ndarray.create_data(generator, ctx) }; - - ndarray - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - }; - let rhs_dim0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_usize.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 300167f7..2d0dca76 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -12,7 +12,7 @@ use crate::{ }; /// Get the zero value in `np.zeros()` of a `dtype`. -pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs new file mode 100644 index 00000000..88a94394 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -0,0 +1,334 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; + +use super::{NDArrayOut, NDArrayValue, RustNDIndex}; +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, + irrt, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ + ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::arraylike_flatten_element_type, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), + (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), +) -> NDArrayValue<'ctx> { + assert!( + in_a.ndims.is_some_and(|ndims| ndims >= 2), + "in_a (which is {:?}) must be compile-time known and >= 2", + in_a.ndims + ); + assert!( + in_b.ndims.is_some_and(|ndims| ndims >= 2), + "in_b (which is {:?}) must be compile-time known and >= 2", + in_b.ndims + ); + + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims = llvm_usize.const_int(ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. + let (lhs, rhs, dst) = { + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_a.shape().base_ptr(ctx, generator), + in_lhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_b.shape().base_ptr(ctx, generator), + in_rhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let dst_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Matmul dimension compatibility is checked here. + irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + &in_lhs_shape, + &in_rhs_shape, + ndims, + &lhs_shape, + &rhs_shape, + &dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); + + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, None); + dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); + unsafe { + dst.create_data(generator, ctx); + } + + (lhs, rhs, dst) + }; + + let len = unsafe { + lhs.shape().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(ndims_int - 1, false), + None, + ) + }; + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices::(); + let i = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None) + }; + let j = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None) + }; + + let num_0 = llvm_usize.const_int(0, false); + let num_1 = llvm_usize.const_int(1, false); + + gen_for_callback_incrementing( + generator, + ctx, + None, + num_0, + (len, false), + |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + k.into(), + ); + } + let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) }; + + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + k.into(), + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) }; + + // Restore `indices`. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs_dtype), a_ik), + Binop::normal(Operator::Mult), + (&Some(rhs_dtype), b_kj), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }, + num_1, + ) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). + /// + /// This function always return an [`NDArrayValue`]. You may want to use + /// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. + #[must_use] + pub fn matmul( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + self_ty: Type, + (other_ty, other): (Type, Self), + (out_dtype, out): (Type, NDArrayOut<'ctx>), + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!( + self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), + "np.matmul disallows scalar input" + ); + + // If both arguments are 2-D they are multiplied like conventional matrices. + // + // If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the + // last two indices and broadcast accordingly. + // + // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its + // dimensions. After matrix multiplication the prepended 1 is removed. + // + // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its + // dimensions. After matrix multiplication the appended 1 is removed. + + let new_a = if self.ndims.unwrap() == 1 { + // Prepend 1 to its dimensions + self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + *self + }; + + let new_b = if other.ndims.unwrap() == 1 { + // Append 1 to its dimensions + other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + other + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = + matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b)); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + let zero = ctx.ctx.i32_type().const_zero(); + + if self.ndims.unwrap() == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if other.ndims.unwrap() == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.shape(); + out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 89f88e74..707c79a2 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -32,6 +32,7 @@ mod broadcast; mod contiguous; mod indexing; mod map; +mod matmul; mod nditer; pub mod shape; mod view; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 41b39bb8..4332b474 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 90408d91..60e0c194 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index f0418889..46601817 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 72e54e02..da58d121 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index a8a534cd..8f384fa1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f03..40bbdeab 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; use super::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, }; @@ -175,19 +175,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +187,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +530,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); From 3ac1083734ff093c77244cc675efb2b9151a56ae Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:32:18 +0800 Subject: [PATCH 36/80] [core] codegen: Reimplement np_dot() for scalars and 1D Based on 693b7f37: core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 136 +++++++++++++++++------------- nac3core/src/toplevel/builtins.rs | 4 +- 2 files changed, 79 insertions(+), 61 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d46a6119..e5a893c9 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,14 +7,18 @@ use nac3parser::ast::StrRef; use super::{ macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, + stmt::gen_for_callback, + types::ndarray::{NDArrayType, NDIterType}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{arraylike_flatten_element_type, extract_ndims}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::typedef::{FunSignature, Type}, }; @@ -300,89 +304,101 @@ pub fn gen_ndarray_fill<'ctx>( pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); + let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = n1.size(generator, ctx); - let n2_sz = n2.size(generator, ctx); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = + ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size), Some(b_size), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx)) + }, + |_, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(ctx); + let b_scalar = b_iter.get_scalar(ctx); - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + _ => codegen_unreachable!( ctx, "{FN_NAME}() not supported for '{}'", diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 538961a6..600276b7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1935,10 +1935,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), From 12fddc3533dcc4c88bac0dfd91a46cf1930723b6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:48:00 +0800 Subject: [PATCH 37/80] [core] codegen/ndarray: Make ndims non-optional Now that everything is ported to use strided impl, dynamic-ndim ndarray instances do not exist anymore. --- nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 2 +- nac3core/src/codegen/builtin_fns.rs | 28 +++---- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 23 +++--- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/codegen/test.rs | 2 +- nac3core/src/codegen/types/ndarray/array.rs | 14 ++-- nac3core/src/codegen/types/ndarray/map.rs | 2 +- nac3core/src/codegen/types/ndarray/mod.rs | 54 ++++--------- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- .../src/codegen/values/ndarray/broadcast.rs | 17 ++--- .../src/codegen/values/ndarray/contiguous.rs | 6 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 33 +++----- nac3core/src/codegen/values/ndarray/mod.rs | 76 +++---------------- nac3core/src/codegen/values/ndarray/view.rs | 11 +-- 17 files changed, 94 insertions(+), 201 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 653f41a3..156ba23e 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -464,7 +464,7 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) .map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -597,7 +597,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 6507dc20..8e9cd10c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1107,7 +1107,7 @@ impl InnerResolver { self.global_value_ids.write().insert(id, obj.into()); } - let ndims = llvm_ndarray.ndims().unwrap(); + let ndims = llvm_ndarray.ndims(); // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 7c8ad7a8..9d368070 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1652,7 +1652,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1694,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1746,8 +1746,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1796,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,7 +1838,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_dyn_shape(generator, ctx, &[d0, d1], None); unsafe { out.create_data(generator, ctx) }; @@ -1880,7 +1880,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1924,7 +1924,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None); + let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, ndims, llvm_usize, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1940,7 +1940,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,7 +1979,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)) + let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) .construct_const_shape(generator, ctx, &[1], None); unsafe { det.create_data(generator, ctx) }; @@ -2008,13 +2008,13 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2053,13 +2053,13 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2ce3c9ab..4e83d530 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -520,7 +520,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into() + NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e5a893c9..6c16be9f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,8 +120,13 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) - .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + generator, + context, + &shape, + fill_value_arg, + None, + ); Ok(ndarray.as_base_value()) } @@ -218,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -246,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -315,8 +320,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. - assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); - assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert_eq!(a.get_type().ndims(), 1); + assert_eq!(b.get_type().ndims(), 1); let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); // Check shapes. diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index edebb4f0..2a3bd066 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -447,10 +447,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) .to_ndarray(generator, ctx); - let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] - .iter() - .filter_map(|ndims| *ndims) - .max(); + let broadcast_ndims = + [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( generator, ctx.ctx, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 97bd3f09..2701e138 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -464,6 +464,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None); + let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 87cd002a..0f30f0eb 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -41,7 +41,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); let list_value = list.as_i8_list(generator, ctx); @@ -61,7 +61,7 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) .construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -93,12 +93,12 @@ impl<'ctx> NDArrayType<'ctx> { if ndims == 1 { // `list` is not nested assert_eq!(ndims, 1); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert!(self.ndims >= ndims); assert_eq!(dtype, self.dtype); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + let ndarray = Self::new(generator, ctx.ctx, dtype, 1) .construct_uninitialized(generator, ctx, name); // Set data @@ -170,7 +170,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. @@ -183,9 +183,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { assert_eq!(ndarray.get_type().dtype, self.dtype); - assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self - .ndims - .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert!(self.ndims >= ndarray.get_type().ndims); assert_eq!(copy.get_type(), ctx.ctx.bool_type()); let ndarray_val = gen_if_else_expr_callback( diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 0d63b22f..bf82b4da 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -47,7 +47,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 43712be1..353ace33 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -38,7 +38,7 @@ mod nditer; pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, } @@ -113,7 +113,7 @@ impl<'ctx> NDArrayType<'ctx> { generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, ) -> Self { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); @@ -132,7 +132,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. @@ -145,7 +145,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize } + NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -164,7 +164,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: Self::llvm_type(ctx.ctx, llvm_usize), dtype: llvm_dtype, - ndims: Some(ndims), + ndims, llvm_usize, } } @@ -174,7 +174,7 @@ impl<'ctx> NDArrayType<'ctx> { pub fn from_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); @@ -196,7 +196,7 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the number of dimensions of this `ndarray` type. #[must_use] - pub fn ndims(&self) -> Option { + pub fn ndims(&self) -> u64 { self.ndims } @@ -286,35 +286,7 @@ impl<'ctx> NDArrayType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - - let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else { - unreachable!() - }; - - self.construct_impl(generator, ctx, ndims, name) - } - - /// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`. - /// - /// `shape` and `strides` will be automatically allocated onto the stack. - /// - /// The returned ndarray's content will be: - /// - `data`: uninitialized. - /// - `itemsize`: set to the size of `dtype`. - /// - `ndims`: set to the value of `ndims`. - /// - `shape`: allocated with an array of length `ndims` with uninitialized values. - /// - `strides`: allocated with an array of length `ndims` with uninitialized values. - #[deprecated = "Prefer construct_uninitialized or construct_*_shape."] - #[must_use] - pub fn construct_dyn_ndims( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> >::Value { - assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)"); + let ndims = self.llvm_usize.const_int(self.ndims, false); self.construct_impl(generator, ctx, ndims, name) } @@ -330,9 +302,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[u64], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -365,9 +337,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[IntValue<'ctx>], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -407,7 +379,7 @@ impl<'ctx> NDArrayType<'ctx> { let value = value.as_basic_value_enum(); assert_eq!(value.get_type(), self.dtype); - assert!(self.ndims.is_none_or(|ndims| ndims == 0)); + assert_eq!(self.ndims, 0); // We have to put the value on the stack to get a data pointer. let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap(); diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 9b71693a..c77e4571 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,13 +163,8 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - assert!( - ndarray.get_type().ndims().is_some(), - "NDIter requires ndims of NDArray to be known." - ); - let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 0c84b05e..1b99f464 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -101,12 +101,11 @@ impl<'ctx> NDArrayValue<'ctx> { target_ndims: u64, target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> Self { - assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = - NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) - .construct_uninitialized(generator, ctx, None); + let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, ctx, @@ -199,14 +198,13 @@ impl<'ctx> NDArrayType<'ctx> { ndarrays: &[NDArrayValue<'ctx>], ) -> BroadcastAllResult<'ctx, G> { assert!(!ndarrays.is_empty()); - assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); let llvm_usize = generator.get_size_type(ctx.ctx); // Infer the broadcast output ndims. let broadcast_ndims_int = - ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); - assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap(); + assert!(self.ndims() >= broadcast_ndims_int); let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); let broadcast_shape = ArraySliceValue::from_ptr_val( @@ -223,10 +221,7 @@ impl<'ctx> NDArrayType<'ctx> { let shape_entries = ndarrays .iter() .map(|ndarray| { - ( - ndarray.shape().as_slice_value(ctx, generator), - ndarray.get_type().ndims().unwrap(), - ) + (ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims()) }) .collect_vec(); broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index f3b03dd1..8eb700b9 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -121,9 +121,7 @@ impl<'ctx> NDArrayValue<'ctx> { .alloca_var(generator, ctx, self.name); // Set ndims and shape. - let ndims = self - .ndims - .map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false)); + let ndims = self.llvm_usize.const_int(self.ndims, false); result.store_ndims(ctx, ndims); let shape = self.shape(); @@ -180,7 +178,7 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) .construct_uninitialized(generator, ctx, carray.name); // Copy shape and update strides diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3d575028..3821f232 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -98,8 +98,8 @@ impl<'ctx> From> for PointerValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Get the expected `ndims` after indexing with `indices`. #[must_use] - fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option { - let mut ndims = self.ndims?; + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; for index in indices { match index { @@ -113,7 +113,7 @@ impl<'ctx> NDArrayValue<'ctx> { } } - Some(ndims) + ndims } /// Index into the ndarray, and return a newly-allocated view on this ndarray. @@ -127,8 +127,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, indices: &[RustNDIndex<'ctx>], ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - let dst_ndims = self.deduce_ndims_after_indexing_with(indices); let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index 88a94394..f802c0c0 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -29,16 +29,8 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), ) -> NDArrayValue<'ctx> { - assert!( - in_a.ndims.is_some_and(|ndims| ndims >= 2), - "in_a (which is {:?}) must be compile-time known and >= 2", - in_a.ndims - ); - assert!( - in_b.ndims.is_some_and(|ndims| ndims >= 2), - "in_b (which is {:?}) must be compile-time known and >= 2", - in_b.ndims - ); + assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims); + assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); @@ -47,13 +39,13 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); // Deduce ndims of the result of matmul. - let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims_int = max(in_a.ndims, in_b.ndims); let ndims = llvm_usize.const_int(ndims_int, false); // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the // destination ndarray to store the result of matmul. let (lhs, rhs, dst) = { - let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false); let in_lhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_a.shape().base_ptr(ctx, generator), @@ -63,7 +55,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( |_, _, val| val.into_int_value(), |_, _, val| val.into(), ); - let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false); let in_rhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_b.shape().base_ptr(ctx, generator), @@ -116,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { @@ -266,10 +258,7 @@ impl<'ctx> NDArrayValue<'ctx> { (out_dtype, out): (Type, NDArrayOut<'ctx>), ) -> Self { // Sanity check, but type inference should prevent this. - assert!( - self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), - "np.matmul disallows scalar input" - ); + assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input"); // If both arguments are 2-D they are multiplied like conventional matrices. // @@ -282,14 +271,14 @@ impl<'ctx> NDArrayValue<'ctx> { // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its // dimensions. After matrix multiplication the appended 1 is removed. - let new_a = if self.ndims.unwrap() == 1 { + let new_a = if self.ndims == 1 { // Prepend 1 to its dimensions self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) } else { *self }; - let new_b = if other.ndims.unwrap() == 1 { + let new_b = if other.ndims == 1 { // Append 1 to its dimensions other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) } else { @@ -305,12 +294,12 @@ impl<'ctx> NDArrayValue<'ctx> { let mut postindices = vec![]; let zero = ctx.ctx.i32_type().const_zero(); - if self.ndims.unwrap() == 1 { + if self.ndims == 1 { // Remove the prepended 1 postindices.push(RustNDIndex::SingleElement(zero)); } - if other.ndims.unwrap() == 1 { + if other.ndims == 1 { // Remove the appended 1 postindices.push(RustNDIndex::Ellipsis); postindices.push(RustNDIndex::SingleElement(zero)); diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 707c79a2..ad50d32e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -42,7 +42,7 @@ mod view; pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -62,7 +62,7 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn from_pointer_value( ptr: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { @@ -245,26 +245,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_shape = src_ndarray.shape().base_ptr(ctx, generator); self.copy_shape_from_array(generator, ctx, src_shape); @@ -296,26 +277,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_strides = src_ndarray.strides().base_ptr(ctx, generator); self.copy_strides_from_array(generator, ctx, src_strides); @@ -380,11 +342,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Self { - let clone = if self.ndims.is_some() { - self.get_type().construct_uninitialized(generator, ctx, None) - } else { - self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None) - }; + let clone = self.get_type().construct_uninitialized(generator, ctx, None); let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); @@ -437,11 +395,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.shape().get_typed_unchecked( @@ -459,7 +415,7 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } @@ -473,11 +429,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.strides().get_typed_unchecked( @@ -495,15 +449,15 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] - pub fn is_unsized(&self) -> Option { - self.ndims.map(|ndims| ndims == 0) + pub fn is_unsized(&self) -> bool { + self.ndims == 0 } /// Returns the element present in this `ndarray` if this is unsized. @@ -512,11 +466,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Option> { - let Some(is_unsized) = self.is_unsized() else { - panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - }; - - if is_unsized { + if self.is_unsized() { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; @@ -534,8 +484,6 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ScalarOrNDArray<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { ScalarOrNDArray::Scalar(unsized_elem) } else { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 450f7444..5027be58 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -26,9 +26,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndmin: u64, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - - let ndims = self.ndims.unwrap(); + let ndims = self.ndims; if ndims < ndmin { // Extend the dimensions with np.newaxis. @@ -67,13 +65,13 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); // Resolve negative indices let size = self.size(generator, ctx); - let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_shape = dst_ndarray.shape(); irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( generator, @@ -121,7 +119,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, axes: Option>, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); assert!( axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) ); @@ -130,7 +127,7 @@ impl<'ctx> NDArrayValue<'ctx> { let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); let axes = if let Some(axes) = axes { - let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false); + let num_axes = self.llvm_usize.const_int(self.ndims, false); // `axes = nullptr` if `axes` is unspecified. let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None); From e480081e4b9d95831958c65721ed42d2d32e3989 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 4 Jan 2025 10:28:27 +0800 Subject: [PATCH 38/80] update dependencies --- Cargo.lock | 103 +++++++++++++++++++++++++++-------------------------- flake.lock | 6 ++-- 2 files changed, 55 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ce59893f..646700d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.4" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = [ "shlex", ] @@ -170,7 +170,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -187,14 +187,14 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static", "libc", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] @@ -221,18 +221,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -249,18 +249,18 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -305,9 +305,9 @@ dependencies = [ [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "equivalent" @@ -378,9 +378,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "hashbrown" @@ -417,11 +417,11 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -472,7 +472,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -559,9 +559,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" @@ -678,7 +678,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", "trybuild", ] @@ -799,7 +799,7 @@ dependencies = [ "phf_shared 0.11.2", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -927,7 +927,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -940,14 +940,14 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1062,9 +1062,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -1095,29 +1095,29 @@ checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.134" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" dependencies = [ "itoa", "memchr", @@ -1226,7 +1226,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -1242,9 +1242,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" dependencies = [ "proc-macro2", "quote", @@ -1265,12 +1265,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1278,9 +1279,9 @@ dependencies = [ [[package]] name = "term" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0" +checksum = "a3bb6001afcea98122260987f8b7b5da969ecad46dbf0b5453702f776b491a41" dependencies = [ "home", "windows-sys 0.52.0", @@ -1325,7 +1326,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] [[package]] @@ -1603,9 +1604,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" dependencies = [ "memchr", ] @@ -1637,5 +1638,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.94", ] diff --git a/flake.lock b/flake.lock index 67bb80e4..7672c219 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1733940404, - "narHash": "sha256-Pj39hSoUA86ZePPF/UXiYHHM7hMIkios8TYG29kQT4g=", + "lastModified": 1735834308, + "narHash": "sha256-dklw3AXr3OGO4/XT1Tu3Xz9n/we8GctZZ75ZWVqAVhk=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5d67ea6b4b63378b9c13be21e2ec9d1afc921713", + "rev": "6df24922a1400241dae323af55f30e4318a6ca65", "type": "github" }, "original": { From 8322d457c601a87fbaa2af735a819a0c1c4622e4 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 4 Jan 2025 15:30:24 +0800 Subject: [PATCH 39/80] standalone/demo: numpy2 compatibility --- nac3standalone/demo/interpret_demo.py | 8 ++++++-- nac3standalone/demo/src/numeric_primitives.py | 20 ++++++------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 8784ce53..fa91ed3c 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -67,7 +67,7 @@ def _bool(x): def _float(x): if isinstance(x, np.ndarray): - return np.float_(x) + return np.float64(x) else: return float(x) @@ -111,6 +111,9 @@ def patch(module): def output_strln(x): print(x, end='') + def output_int32_list(x): + print([int(e) for e in x]) + def dbg_stack_address(_): return 0 @@ -126,11 +129,12 @@ def patch(module): return output_float elif name == "output_str": return output_strln + elif name == "output_int32_list": + return output_int32_list elif name in { "output_bool", "output_int32", "output_int64", - "output_int32_list", "output_uint32", "output_uint64", "output_strln", diff --git a/nac3standalone/demo/src/numeric_primitives.py b/nac3standalone/demo/src/numeric_primitives.py index 77a641f8..2c08c16d 100644 --- a/nac3standalone/demo/src/numeric_primitives.py +++ b/nac3standalone/demo/src/numeric_primitives.py @@ -29,10 +29,10 @@ def u32_max() -> uint32: return ~uint32(0) def i32_min() -> int32: - return int32(1 << 31) + return int32(-(1 << 31)) def i32_max() -> int32: - return int32(~(1 << 31)) + return int32((1 << 31)-1) def u64_min() -> uint64: return uint64(0) @@ -63,8 +63,9 @@ def test_conv_from_i32(): i32_max() ]: output_int64(int64(x)) - output_uint32(uint32(x)) - output_uint64(uint64(x)) + if x >= 0: + output_uint32(uint32(x)) + output_uint64(uint64(x)) output_float64(float(x)) def test_conv_from_u32(): @@ -108,7 +109,6 @@ def test_conv_from_u64(): def test_f64toi32(): for x in [ - float(i32_min()) - 1.0, float(i32_min()), float(i32_min()) + 1.0, -1.5, @@ -117,7 +117,6 @@ def test_f64toi32(): 1.5, float(i32_max()) - 1.0, float(i32_max()), - float(i32_max()) + 1.0 ]: output_int32(int32(x)) @@ -138,24 +137,17 @@ def test_f64toi64(): def test_f64tou32(): for x in [ - -1.5, - float(u32_min()) - 1.0, - -0.5, float(u32_min()), 0.5, float(u32_min()) + 1.0, 1.5, float(u32_max()) - 1.0, float(u32_max()), - float(u32_max()) + 1.0 ]: output_uint32(uint32(x)) def test_f64tou64(): for x in [ - -1.5, - float(u64_min()) - 1.0, - -0.5, float(u64_min()), 0.5, float(u64_min()) + 1.0, @@ -181,4 +173,4 @@ def run() -> int32: test_f64tou32() test_f64tou64() - return 0 \ No newline at end of file + return 0 From d9c180ed13df48fcfcc2b6b148741c09732742bb Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 13:36:51 +0800 Subject: [PATCH 40/80] [artiq] symbol_resolver: Fix support for np.bool_ -> bool decay --- nac3artiq/demo/numpy_primitives_decay.py | 29 ++++++++++++++++++++++++ nac3artiq/src/symbol_resolver.rs | 22 +++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 nac3artiq/demo/numpy_primitives_decay.py diff --git a/nac3artiq/demo/numpy_primitives_decay.py b/nac3artiq/demo/numpy_primitives_decay.py new file mode 100644 index 00000000..957d363f --- /dev/null +++ b/nac3artiq/demo/numpy_primitives_decay.py @@ -0,0 +1,29 @@ +from min_artiq import * +import numpy +from numpy import int32 + + +@nac3 +class NumpyBoolDecay: + core: KernelInvariant[Core] + np_true: KernelInvariant[bool] + np_false: KernelInvariant[bool] + np_int: KernelInvariant[int32] + np_float: KernelInvariant[float] + np_str: KernelInvariant[str] + + def __init__(self): + self.core = Core() + self.np_true = numpy.True_ + self.np_false = numpy.False_ + self.np_int = numpy.int32(0) + self.np_float = numpy.float64(0.0) + self.np_str = numpy.str_("") + + @kernel + def run(self): + pass + + +if __name__ == "__main__": + NumpyBoolDecay().run() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 8e9cd10c..0f9d57bd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -931,10 +931,13 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - obj.extract::().map_or_else( - |_| Ok(Err(format!("{obj} is not in the range of bool"))), - |_| Ok(Ok(extracted_ty)), - ) + if let Ok(_) = obj.extract::() { + Ok(Ok(extracted_ty)) + } else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::() { + Ok(Ok(extracted_ty)) + } else { + Ok(Err(format!("{obj} is not in the range of bool"))) + } } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of float64"))), @@ -974,10 +977,14 @@ impl InnerResolver { let val: u64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); @@ -1413,9 +1420,12 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.uint64 { let val: u64 = obj.extract()?; Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract()?; + Ok(SymbolValue::Bool(val)) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract()?; Ok(SymbolValue::Str(val)) From 2271b46b9631f8b3892666fbdcefe9061621dd8c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:31:17 +0800 Subject: [PATCH 41/80] [core] codegen/values/ndarray: Fix Vec allocation --- nac3core/src/codegen/values/ndarray/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index ad50d32e..b32a8f63 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -950,10 +950,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, /// This function is used generating strides for globally defined contiguous ndarrays. #[must_use] pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { - let mut strides = Vec::with_capacity(ndims as usize); + let mut strides = vec![0u64; ndims as usize]; let mut stride_product = 1u64; - for i in 0..ndims { - let axis = ndims - i - 1; + for axis in (0..ndims).rev() { strides[axis as usize] = stride_product * itemsize; stride_product *= shape[axis as usize]; } From 4e21def1a026caa2f860bf80ae47607b5a103cff Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:36:35 +0800 Subject: [PATCH 42/80] [artiq] symbol_resolver: Add missing promotion for host compilation Shape tuple is always in i32, so a zero-extension to i64 is necessary when assigning the shape tuple into the shape field of the ndarray. --- nac3artiq/src/symbol_resolver.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0f9d57bd..04d224cd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1131,7 +1131,10 @@ impl InnerResolver { super::CompileError::new_err(format!("Error getting element {i}: {e}")) })? .unwrap(); - let value = value.into_int_value(); + let value = ctx + .builder + .build_int_z_extend(value.into_int_value(), llvm_usize, "") + .unwrap(); Ok(value) }) .collect::, PyErr>>()?; From 3c5e247195f52b2296cca20bf54267a017e24b1e Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:39:24 +0800 Subject: [PATCH 43/80] [artiq] symbol_resolver: Use TargetData to get size of dtype dtype.size_of() may not return a constant value. --- nac3artiq/src/symbol_resolver.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 04d224cd..7b70cc49 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1213,8 +1213,16 @@ impl InnerResolver { data_global.set_initializer(&data); // Get the constant itemsize. - let itemsize = dtype.size_of().unwrap(); - let itemsize = itemsize.get_zero_extended_constant().unwrap(); + // + // NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size` + // will always return a constant size. + let itemsize = ctx + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data().get_store_size(&dtype)) + .unwrap(); + assert_ne!(itemsize, 0); // Create the strides needed for ndarray.strides let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); From 352c7c880b3604f78b84d31f86114ce8c79b2830 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:41:21 +0800 Subject: [PATCH 44/80] [artiq] symbol_resolver: Fix incorrect global type for ndarray.strides --- nac3artiq/src/symbol_resolver.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 7b70cc49..75600cc2 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1232,7 +1232,7 @@ impl InnerResolver { // create a global for ndarray.strides and initialize it let strides_global = ctx.module.add_global( - llvm_i8.array_type(ndims as u32), + llvm_usize.array_type(ndims as u32), Some(AddressSpace::default()), &format!("${id_str}.strides"), ); From eaaa194a87082ee5530f724b45c9822466aafe56 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 16:42:53 +0800 Subject: [PATCH 45/80] [artiq] symbol_resolver: Cast ndarray.{shape,strides} globals to usize* This is needed as ndarray.{shapes,strides} are ArrayValues, and so a GEP operation is required to convert them into pointers to their first elements. --- nac3artiq/src/symbol_resolver.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 75600cc2..232cd084 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1248,9 +1248,30 @@ impl InnerResolver { let ndarray_ndims = llvm_usize.const_int(ndims, false); + // calling as_pointer_value on shape and strides returns [i64 x ndims]* + // convert into i64* to conform with expected layout of ndarray + let ndarray_shape = shape_global.as_pointer_value(); + let ndarray_shape = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_shape, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; let ndarray_strides = strides_global.as_pointer_value(); + let ndarray_strides = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_strides, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; let ndarray = llvm_ndarray .as_base_type() From 8baf111734c6c29504a0e1804b8f8d56600896f7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 17:08:03 +0800 Subject: [PATCH 46/80] [meta] Apply clippy suggestions --- nac3artiq/src/codegen.rs | 6 +++--- nac3artiq/src/lib.rs | 4 ++-- nac3artiq/src/symbol_resolver.rs | 6 +++--- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/toplevel/test.rs | 2 +- nac3core/src/typecheck/function_check.rs | 4 ++-- nac3core/src/typecheck/type_error.rs | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 4 ++-- nac3ld/src/dwarf.rs | 8 ++++---- nac3ld/src/lib.rs | 2 +- 11 files changed, 21 insertions(+), 21 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 156ba23e..9baa0afe 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -162,7 +162,7 @@ impl<'a> ArtiqCodeGenerator<'a> { } } -impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { +impl CodeGenerator for ArtiqCodeGenerator<'_> { fn get_name(&self) -> &str { &self.name } @@ -1505,7 +1505,7 @@ pub fn call_rtio_log_impl<'ctx>( /// Generates a call to `core_log`. pub fn gen_core_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, @@ -1522,7 +1522,7 @@ pub fn gen_core_log<'ctx>( /// Generates a call to `rtio_log`. pub fn gen_rtio_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ca2f2f15..1601d4a4 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -330,7 +330,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_core_log(ctx, &obj, fun, &args, generator)?; + gen_core_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), @@ -360,7 +360,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_rtio_log(ctx, &obj, fun, &args, generator)?; + gen_rtio_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 232cd084..1e99ac20 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -931,9 +931,9 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - if let Ok(_) = obj.extract::() { - Ok(Ok(extracted_ty)) - } else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::() { + if obj.extract::().is_ok() + || obj.call_method("__bool__", (), None)?.extract::().is_ok() + { Ok(Ok(extracted_ty)) } else { Ok(Err(format!("{obj} is not in the range of bool"))) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4b83e63e..e74e7ad8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -79,7 +79,7 @@ pub fn get_subst_key( .join(", ") } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Builds a sequence of `getelementptr` and `load` instructions which stores the value of a /// struct field into an LLVM value. pub fn build_gep_and_load( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 4e83d530..e74071bc 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -228,7 +228,7 @@ pub struct CodeGenContext<'ctx, 'a> { pub current_loc: Location, } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl CodeGenContext<'_, '_> { /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 1f33a4ba..6a83632e 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -807,7 +807,7 @@ struct TypeToStringFolder<'a> { unifier: &'a mut Unifier, } -impl<'a> Fold> for TypeToStringFolder<'a> { +impl Fold> for TypeToStringFolder<'_> { type TargetU = String; type Error = String; fn map_user(&mut self, user: Option) -> Result { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index ed801a17..2e655d11 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -15,7 +15,7 @@ use super::{ }; use crate::toplevel::helper::PrimDef; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)])) @@ -94,7 +94,7 @@ impl<'a> Inferencer<'a> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) - && !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id()) + && ty.obj_id(self.unifier).is_none_or(|id| id != PrimDef::List.id()) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { return Err(HashSet::from([format!( diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index 3cf5bdc4..b144f8a9 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -94,7 +94,7 @@ fn loc_to_str(loc: Option) -> String { } } -impl<'a> Display for DisplayTypeError<'a> { +impl Display for DisplayTypeError<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { use TypeErrorKind::*; let mut notes = Some(HashMap::new()); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 87692114..742fa197 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located) { } } -impl<'a> Fold<()> for Inferencer<'a> { +impl Fold<()> for Inferencer<'_> { type TargetU = Option; type Error = InferenceError; @@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> { type InferenceResult = Result; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { /// Constrain a <: b /// Currently implemented as unification fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { diff --git a/nac3ld/src/dwarf.rs b/nac3ld/src/dwarf.rs index e85a4e40..4bcccd32 100644 --- a/nac3ld/src/dwarf.rs +++ b/nac3ld/src/dwarf.rs @@ -30,7 +30,7 @@ pub struct DwarfReader<'a> { pub virt_addr: u32, } -impl<'a> DwarfReader<'a> { +impl DwarfReader<'_> { pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { DwarfReader { slice, virt_addr } } @@ -113,7 +113,7 @@ pub struct DwarfWriter<'a> { pub offset: usize, } -impl<'a> DwarfWriter<'a> { +impl DwarfWriter<'_> { pub fn new(slice: &mut [u8]) -> DwarfWriter { DwarfWriter { slice, offset: 0 } } @@ -375,7 +375,7 @@ pub struct FDE_Records<'a> { available: usize, } -impl<'a> Iterator for FDE_Records<'a> { +impl Iterator for FDE_Records<'_> { type Item = (u32, u32); fn next(&mut self) -> Option { @@ -423,7 +423,7 @@ pub struct EH_Frame_Hdr<'a> { fdes: Vec<(u32, u32)>, } -impl<'a> EH_Frame_Hdr<'a> { +impl EH_Frame_Hdr<'_> { /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// /// Load address is not known at this point. diff --git a/nac3ld/src/lib.rs b/nac3ld/src/lib.rs index 73e065dd..8ef8653a 100644 --- a/nac3ld/src/lib.rs +++ b/nac3ld/src/lib.rs @@ -159,7 +159,7 @@ struct SymbolTableReader<'a> { strtab: &'a [u8], } -impl<'a> SymbolTableReader<'a> { +impl SymbolTableReader<'_> { pub fn find_index_by_name(&self, sym_name: &[u8]) -> Option { self.symtab.iter().position(|sym| { if let Ok(dynsym_name) = name_starting_at_slice(self.strtab, sym.st_name as usize) { From d1dcfa19ff5696a8a109756398d0fcb3fc970536 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 14:55:33 +0800 Subject: [PATCH 47/80] CodeGenerator: Add with_target_machine factory function Allows creating CodeGenerator with the LLVM target machine to infer the expected type for size_t. --- nac3artiq/src/codegen.rs | 18 +++++++++++++++--- nac3artiq/src/lib.rs | 19 ++++++++++++++----- nac3core/src/codegen/generator.rs | 19 ++++++++++++++----- nac3core/src/codegen/test.rs | 10 ++++++---- nac3standalone/src/main.rs | 8 +++++++- 5 files changed, 56 insertions(+), 18 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 9baa0afe..daf539fd 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -29,6 +29,7 @@ use nac3core::{ inkwell::{ context::Context, module::Linkage, + targets::TargetMachine, types::{BasicType, IntType}, values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, @@ -87,13 +88,13 @@ pub struct ArtiqCodeGenerator<'a> { impl<'a> ArtiqCodeGenerator<'a> { pub fn new( name: String, - size_t: u32, + size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), ) -> ArtiqCodeGenerator<'a> { - assert!(size_t == 32 || size_t == 64); + assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { name, - size_t, + size_t: size_t.get_bit_width(), name_counter: 0, start: None, end: None, @@ -102,6 +103,17 @@ impl<'a> ArtiqCodeGenerator<'a> { } } + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + timeline: &'a (dyn TimeFns + Sync), + ) -> ArtiqCodeGenerator<'a> { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize, timeline) + } + /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the /// position of the timeline to the initial timeline position before entering the `parallel` /// block. diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 1601d4a4..9a69c1a3 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -703,14 +703,18 @@ impl Nac3 { let buffer = buffer.as_slice().into(); membuffer.lock().push(buffer); }))); - let size_t = context - .ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None) - .get_bit_width(); let num_threads = if is_multithreaded() { 4 } else { 1 }; let thread_names: Vec = (0..num_threads).map(|_| "main".to_string()).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) + .map(|s| { + Box::new(ArtiqCodeGenerator::with_target_machine( + s.to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + )) + }) .collect(); let membuffer = membuffers.clone(); @@ -719,8 +723,13 @@ impl Nac3 { let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); - let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns); let context = Context::create(); + let mut generator = ArtiqCodeGenerator::with_target_machine( + "main".to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); module.set_data_layout(&target_machine.get_target_data().get_data_layout()); diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index be007c2a..a416f10a 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -1,5 +1,6 @@ use inkwell::{ context::Context, + targets::TargetMachine, types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue, PointerValue}, }; @@ -270,19 +271,27 @@ pub struct DefaultCodeGenerator { impl DefaultCodeGenerator { #[must_use] - pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator { - assert!(matches!(size_t, 32 | 64)); - DefaultCodeGenerator { name, size_t } + pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator { + assert!(matches!(size_t.get_bit_width(), 32 | 64)); + DefaultCodeGenerator { name, size_t: size_t.get_bit_width() } + } + + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + ) -> DefaultCodeGenerator { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize) } } impl CodeGenerator for DefaultCodeGenerator { - /// Returns the name for this [`CodeGenerator`]. fn get_name(&self) -> &str { &self.name } - /// Returns an LLVM integer type representing `size_t`. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { // it should be unsigned, but we don't really need unsigned and this could save us from // having to do a bit cast... diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 2701e138..48bef5f2 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -97,6 +97,7 @@ fn test_primitives() { "}; let statements = parse_program(source, FileName::default()).unwrap(); + let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; @@ -107,7 +108,7 @@ fn test_primitives() { Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) as Arc; - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; let signature = FunSignature { args: vec![ FuncArg { @@ -260,6 +261,7 @@ fn test_simple_call() { "}; let statements_2 = parse_program(source_2, FileName::default()).unwrap(); + let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; @@ -307,7 +309,7 @@ fn test_simple_call() { unreachable!() } - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), @@ -439,7 +441,7 @@ fn test_simple_call() { #[test] fn test_classes_list_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); @@ -459,7 +461,7 @@ fn test_classes_range_type_new() { #[test] fn test_classes_ndarray_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 2fce5d16..d54e08e7 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -456,7 +456,13 @@ fn main() { membuffer.lock().push(buffer); }))); let threads = (0..threads) - .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t))) + .map(|i| { + Box::new(DefaultCodeGenerator::with_target_machine( + format!("module{i}"), + &context, + &target_machine, + )) + }) .collect(); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); From 3ebd4ba5d14277a00e953f4dc354fb2fe230156b Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 14:56:22 +0800 Subject: [PATCH 48/80] [core] codegen: Add assertion verifying size_t is compatible --- nac3core/src/codegen/mod.rs | 11 +++++++++++ nac3core/src/codegen/test.rs | 8 ++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e74071bc..28b9a654 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -989,6 +989,17 @@ pub fn gen_func_impl< debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), }; + let target_llvm_usize = context.ptr_sized_int_type( + ®istry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(), + None, + ); + let generator_llvm_usize = generator.get_size_type(context); + assert_eq!( + generator_llvm_usize, + target_llvm_usize, + "CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})", + ); + let loc = code_gen_context.debug_info.0.create_debug_location( context, row as u32, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 48bef5f2..6518d858 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -98,7 +98,7 @@ fn test_primitives() { let statements = parse_program(source, FileName::default()).unwrap(); let context = inkwell::context::Context::create(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -108,7 +108,7 @@ fn test_primitives() { Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) as Arc; - let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let signature = FunSignature { args: vec![ FuncArg { @@ -262,7 +262,7 @@ fn test_simple_call() { let statements_2 = parse_program(source_2, FileName::default()).unwrap(); let context = inkwell::context::Context::create(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -309,7 +309,7 @@ fn test_simple_call() { unreachable!() } - let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), From f8530e0ef694c7d8ac7690ecd9eab838b5d59831 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 20:26:15 +0800 Subject: [PATCH 49/80] [core] codegen: Add CodeGenContext::get_size_type Convenience method for getting the `size_t` LLVM type without the use of `CodeGenerator`. --- nac3core/src/codegen/generator.rs | 3 +++ nac3core/src/codegen/mod.rs | 25 +++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index a416f10a..620ede0e 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -19,6 +19,9 @@ pub trait CodeGenerator { fn get_name(&self) -> &str; /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. + /// + /// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is + /// equivalent to this function in a more concise syntax. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 28b9a654..797a62be 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,4 +1,5 @@ use std::{ + cell::OnceCell, collections::{HashMap, HashSet}, sync::{ atomic::{AtomicBool, Ordering}, @@ -19,7 +20,7 @@ use inkwell::{ module::Module, passes::PassBuilderOptions, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, - types::{AnyType, BasicType, BasicTypeEnum}, + types::{AnyType, BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -226,14 +227,33 @@ pub struct CodeGenContext<'ctx, 'a> { /// The current source location. pub current_loc: Location, + + /// The cached type of `size_t`. + llvm_usize: OnceCell>, } -impl CodeGenContext<'_, '_> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some() } + + /// Returns a [`IntType`] representing `size_t` for the compilation target as specified by + /// [`self.registry`][WorkerRegistry]. + pub fn get_size_type(&self) -> IntType<'ctx> { + *self.llvm_usize.get_or_init(|| { + self.ctx.ptr_sized_int_type( + &self + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data()) + .unwrap(), + None, + ) + }) + } } type Fp = Box; @@ -987,6 +1007,7 @@ pub fn gen_func_impl< need_sret: has_sret, current_loc: Location::default(), debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), + llvm_usize: OnceCell::default(), }; let target_llvm_usize = context.ptr_sized_int_type( From c59fd286ff166c7f1e20f8e8fbaf80666f815faa Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 20:43:57 +0800 Subject: [PATCH 50/80] [artiq] Move `get_llvm_*` to Isa, use `TargetMachine` to infer size_t --- nac3artiq/src/lib.rs | 113 +++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9a69c1a3..d35e66d1 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -78,14 +78,62 @@ enum Isa { } impl Isa { - /// Returns the number of bits in `size_t` for the [`Isa`]. - fn get_size_type(self) -> u32 { - if self == Isa::Host { - 64u32 - } else { - 32u32 + /// Returns the [`TargetTriple`] used for compiling to this ISA. + pub fn get_llvm_target_triple(self) -> TargetTriple { + match self { + Isa::Host => TargetMachine::get_default_triple(), + Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), + Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), } } + + /// Returns the [`String`] representing the target CPU used for compiling to this ISA. + pub fn get_llvm_target_cpu(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_name().to_string(), + Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), + Isa::CortexA9 => "cortex-a9".to_string(), + } + } + + /// Returns the [`String`] representing the target features used for compiling to this ISA. + pub fn get_llvm_target_features(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_features().to_string(), + Isa::RiscV32G => "+a,+m,+f,+d".to_string(), + Isa::RiscV32IMA => "+a,+m".to_string(), + Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), + } + } + + /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine + /// options used for compiling to this ISA. + pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions { + CodeGenTargetMachineOptions { + triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(), + cpu: self.get_llvm_target_cpu(), + features: self.get_llvm_target_features(), + reloc_mode: RelocMode::PIC, + ..CodeGenTargetMachineOptions::from_host() + } + } + + /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this + /// ISA. + pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine { + self.get_llvm_target_options() + .create_target_machine(opt_level) + .expect("couldn't create target machine") + } + + /// Returns the number of bits in `size_t` for this ISA. + fn get_size_type(self, ctx: &Context) -> u32 { + ctx.ptr_sized_int_type( + &self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(), + None, + ) + .get_bit_width() + } } #[derive(Clone)] @@ -378,7 +426,7 @@ impl Nac3 { py: Python, link_fn: &dyn Fn(&Module) -> PyResult, ) -> PyResult { - let size_t = self.isa.get_size_type(); + let size_t = self.isa.get_size_type(&Context::create()); let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( self.builtins.clone(), Self::get_lateinit_builtins(), @@ -848,52 +896,10 @@ impl Nac3 { link_fn(&main) } - /// Returns the [`TargetTriple`] used for compiling to [isa]. - fn get_llvm_target_triple(isa: Isa) -> TargetTriple { - match isa { - Isa::Host => TargetMachine::get_default_triple(), - Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), - Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), - } - } - - /// Returns the [`String`] representing the target CPU used for compiling to [isa]. - fn get_llvm_target_cpu(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_name().to_string(), - Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), - Isa::CortexA9 => "cortex-a9".to_string(), - } - } - - /// Returns the [`String`] representing the target features used for compiling to [isa]. - fn get_llvm_target_features(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_features().to_string(), - Isa::RiscV32G => "+a,+m,+f,+d".to_string(), - Isa::RiscV32IMA => "+a,+m".to_string(), - Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), - } - } - - /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine - /// options used for compiling to [isa]. - fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions { - CodeGenTargetMachineOptions { - triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(), - cpu: Nac3::get_llvm_target_cpu(isa), - features: Nac3::get_llvm_target_features(isa), - reloc_mode: RelocMode::PIC, - ..CodeGenTargetMachineOptions::from_host() - } - } - /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the - /// target [isa]. + /// target [ISA][isa]. fn get_llvm_target_machine(&self) -> TargetMachine { - Nac3::get_llvm_target_options(self.isa) - .create_target_machine(self.llvm_options.opt_level) - .expect("couldn't create target machine") + self.isa.create_llvm_target_machine(self.llvm_options.opt_level) } } @@ -1001,7 +1007,8 @@ impl Nac3 { Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, }; - let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type()); + let (primitive, _) = + TopLevelComposer::make_primitives(isa.get_size_type(&Context::create())); let builtins = vec![ ( "now_mu".into(), @@ -1150,7 +1157,7 @@ impl Nac3 { deferred_eval_store: DeferredEvaluationStore::new(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, - target: Nac3::get_llvm_target_options(isa), + target: isa.get_llvm_target_options(), }, }) } From bd66fe48d8acb011ba3b3dc1a6c507db91b56fb6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 21:05:27 +0800 Subject: [PATCH 51/80] [core] codegen: Refactor to use CodeGenContext::get_size_type Simplifies a lot of API usage. --- nac3artiq/src/codegen.rs | 10 +-- nac3artiq/src/symbol_resolver.rs | 4 +- nac3core/src/codegen/builtin_fns.rs | 18 ++-- nac3core/src/codegen/expr.rs | 48 +++++------ nac3core/src/codegen/irrt/list.rs | 4 +- nac3core/src/codegen/irrt/mod.rs | 8 +- nac3core/src/codegen/irrt/ndarray/array.rs | 18 ++-- nac3core/src/codegen/irrt/ndarray/basic.rs | 67 ++++++--------- .../src/codegen/irrt/ndarray/broadcast.rs | 9 +- nac3core/src/codegen/irrt/ndarray/indexing.rs | 2 +- nac3core/src/codegen/irrt/ndarray/iter.rs | 17 ++-- nac3core/src/codegen/irrt/ndarray/matmul.rs | 5 +- nac3core/src/codegen/irrt/ndarray/reshape.rs | 3 +- .../src/codegen/irrt/ndarray/transpose.rs | 4 +- nac3core/src/codegen/irrt/string.rs | 7 +- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 18 ++-- nac3core/src/codegen/stmt.rs | 14 ++- nac3core/src/codegen/types/list.rs | 6 +- nac3core/src/codegen/types/ndarray/array.rs | 6 +- .../src/codegen/types/ndarray/contiguous.rs | 2 +- nac3core/src/codegen/types/ndarray/map.rs | 10 +-- nac3core/src/codegen/types/ndarray/mod.rs | 10 +-- nac3core/src/codegen/types/tuple.rs | 2 +- nac3core/src/codegen/values/array.rs | 2 +- nac3core/src/codegen/values/list.rs | 11 +-- .../src/codegen/values/ndarray/broadcast.rs | 6 +- .../src/codegen/values/ndarray/contiguous.rs | 4 +- .../src/codegen/values/ndarray/indexing.rs | 2 +- nac3core/src/codegen/values/ndarray/matmul.rs | 4 +- nac3core/src/codegen/values/ndarray/mod.rs | 85 ++++++------------- nac3core/src/codegen/values/ndarray/nditer.rs | 18 ++-- nac3core/src/codegen/values/ndarray/shape.rs | 2 +- nac3core/src/codegen/values/ndarray/view.rs | 8 +- nac3core/src/toplevel/builtins.rs | 6 +- 35 files changed, 176 insertions(+), 266 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index daf539fd..cb75606d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -471,7 +471,7 @@ fn format_rpc_arg<'ctx>( // libproto_artiq: NDArray = [data[..], dim_sz[..]] let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); @@ -556,7 +556,7 @@ fn format_rpc_ret<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { @@ -697,7 +697,7 @@ fn format_rpc_ret<'ctx>( // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let num_elements = ndarray.size(generator, ctx); + let num_elements = ndarray.size(ctx); let expected_ndarray_nbytes = ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap(); @@ -809,7 +809,7 @@ fn rpc_codegen_callback_fn<'ctx>( ) -> Result>, String> { let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); - let size_type = generator.get_size_type(ctx.ctx); + let size_type = ctx.get_size_type(); let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); @@ -1167,7 +1167,7 @@ fn polymorphic_print<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let suffix = suffix.unwrap_or_default(); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 1e99ac20..d9768669 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1007,7 +1007,7 @@ impl InnerResolver { } _ => unreachable!("must be list"), }; - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let ty = if len == 0 && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) { @@ -1096,7 +1096,7 @@ impl InnerResolver { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); let dtype = llvm_ndarray.element_type(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 9d368070..96f8c700 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -64,7 +64,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) .map_value(arg.into_pointer_value(), None); ctx.builder - .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") + .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") .unwrap() } @@ -835,7 +835,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { @@ -870,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") + .build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "") .unwrap(); ctx.make_assert( @@ -1676,7 +1676,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_qr"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1728,7 +1728,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_svd"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1821,7 +1821,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_pinv"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1862,7 +1862,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "sp_linalg_lu"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1915,7 +1915,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) @@ -1968,7 +1968,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e74e7ad8..00290d34 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -165,7 +165,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); ty.const_named_struct(&[str_ptr, size.into()]).into() } @@ -318,7 +318,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str); let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); @@ -820,7 +820,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result>, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let id; @@ -1020,7 +1020,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( } let is_vararg = args.iter().any(|arg| arg.is_vararg); if is_vararg { - params.push(generator.get_size_type(ctx.ctx).into()); + params.push(ctx.get_size_type().into()); } let fun_ty = match ret_type { Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg), @@ -1128,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( return Ok(None); }; let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero_size_t = size_t.const_zero(); let zero_32 = int32.const_zero(); @@ -1258,15 +1258,13 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } // Emits the content of `cont_bb` - let emit_cont_bb = - |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { - ctx.builder.position_at_end(cont_bb); - list.store_size( - ctx, - generator, - ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), - ); - }; + let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| { + ctx.builder.position_at_end(cont_bb); + list.store_size( + ctx, + ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), + ); + }; for cond in ifs { let result = if let Some(v) = generator.gen_expr(ctx, cond)? { @@ -1274,7 +1272,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } else { // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // no element matches the predicate - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1287,7 +1285,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let Some(elem) = generator.gen_expr(ctx, elt)? else { // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1304,7 +1302,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( .unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap(); - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); Ok(Some(list.as_base_value().into())) } @@ -1350,7 +1348,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); if op.variant == BinopVariant::AugAssign { todo!("Augmented assignment operators not implemented for lists") @@ -1972,7 +1970,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let rhs = rhs.into_struct_value(); let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); @@ -2000,7 +1998,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); + let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { ctx.builder.build_not(result, "").unwrap() } else { @@ -2010,7 +2008,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .iter() .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let gen_list_cmpop = |generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>| @@ -2375,7 +2373,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ) -> Result>, String> { ctx.current_loc = expr.location; let int32 = ctx.ctx.i32_type(); - let usize = generator.get_size_type(ctx.ctx); + let usize = ctx.get_size_type(); let zero = int32.const_int(0, false); let loc = ctx.debug_info.0.create_debug_location( @@ -2480,7 +2478,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { Some(elements[0].get_type()) }; - let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); + let length = ctx.get_size_type().const_int(elements.len() as u64, false); let arr_str_ptr = if let Some(ty) = ty { ListType::new(generator, ctx.ctx, ty).construct( generator, @@ -3009,7 +3007,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let raw_index = ctx .builder - .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") + .build_int_s_extend(raw_index, ctx.get_size_type(), "sext") .unwrap(); // handle negative index let is_negative = ctx @@ -3017,7 +3015,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, raw_index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index 2c57f8e7..c01e2cb6 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,7 +24,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let llvm_i32 = ctx.ctx.i32_type(); @@ -168,7 +168,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.position_at_end(update_bb); let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); - dest_arr.store_size(ctx, generator, new_len); + dest_arr.store_size(ctx, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 4cacdccb..87391780 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -68,13 +68,9 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. #[must_use] -pub fn get_usize_dependent_function_name( - generator: &G, - ctx: &CodeGenContext<'_, '_>, - name: &str, -) -> String { +pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String { let mut name = name.to_owned(); - match generator.get_size_type(ctx.ctx).get_bit_width() { + match ctx.get_size_type().get_bit_width() { 32 => {} 64 => name.push_str("64"), bit_width => { diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs index 931b66cb..5e9c0f0b 100644 --- a/nac3core/src/codegen/irrt/ndarray/array.rs +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -21,7 +21,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato ndims: IntValue<'ctx>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(ndims.get_type(), llvm_usize); assert_eq!( @@ -29,11 +29,8 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_array_set_and_validate_list_shape", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape"); infer_and_call_function( ctx, @@ -55,19 +52,14 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato /// - `ndarray.ndims`: Must be initialized. /// - `ndarray.shape`: Must be initialized. /// - `ndarray.data`: Must be allocated and contiguous. -pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>, ndarray: NDArrayValue<'ctx>, ) { assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_array_write_list_to_array", - ); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index d11c9b8d..aa792b15 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ctx: &CodeGenContext<'ctx, '_>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -28,11 +28,8 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_shape_no_negative", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); create_and_call_function( ctx, @@ -57,7 +54,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -69,11 +66,8 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_output_shape_same", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); create_and_call_function( ctx, @@ -94,15 +88,14 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an /// `ndarray`, corresponding to the value of `ndarray.size`. -pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_size<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); create_and_call_function( ctx, @@ -120,15 +113,14 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the /// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. -pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_nbytes<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); create_and_call_function( ctx, @@ -146,15 +138,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of /// the `ndarray`, corresponding to the value of `ndarray.__len__`. -pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_len<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); create_and_call_function( ctx, @@ -171,15 +162,14 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_ndarray_is_c_contiguous`. /// /// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. -pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); create_and_call_function( ctx, @@ -196,20 +186,19 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_ndarray_get_nth_pelement`. /// /// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. -pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, index: IntValue<'ctx>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); assert_eq!(index.get_type(), llvm_usize); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( ctx, @@ -236,7 +225,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_ndarray = ndarray.get_type().as_base_type(); @@ -245,8 +234,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized llvm_usize.into() ); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); create_and_call_function( ctx, @@ -266,15 +254,13 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized /// Generates a call to `__nac3_ndarray_set_strides_by_shape`. /// /// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. -pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); create_and_call_function( ctx, @@ -291,13 +277,12 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( /// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number /// of elements in `src_ndarray` must be greater than or equal to the number of elements in /// `dst_ndarray`. -pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_copy_data<'ctx>( ctx: &CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index cb1ecd4c..fceba25f 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -20,13 +20,12 @@ use crate::codegen::{ /// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. /// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. /// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. -pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_broadcast_to<'ctx>( ctx: &CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to"); infer_and_call_function( ctx, &name, @@ -53,7 +52,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(num_shape_entries.get_type(), llvm_usize); assert!(ShapeEntryType::is_type( @@ -65,7 +64,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( assert_eq!(dst_ndims.get_type(), llvm_usize); assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 3e2c908d..df5b27de 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -17,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 47cd5b29..ad90178c 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -25,7 +25,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ndarray: NDArrayValue<'ctx>, indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -33,7 +33,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.into() ); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); create_and_call_function( ctx, @@ -53,12 +53,11 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` /// object. -pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_nditer_has_element<'ctx>( ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ) -> IntValue<'ctx> { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element"); infer_and_call_function( ctx, @@ -75,12 +74,8 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_nditer_next`. /// /// Moves `iter` to point to the next element. -pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - iter: NDIterValue<'ctx>, -) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next"); +pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs index 551cb7c7..0df774fe 100644 --- a/nac3core/src/codegen/irrt/ndarray/matmul.rs +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!( BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), @@ -43,8 +43,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized llvm_usize.into() ); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs index 32de2fa1..66cbf132 100644 --- a/nac3core/src/codegen/irrt/ndarray/reshape.rs +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -18,14 +18,13 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera new_ndims: IntValue<'ctx>, new_shape: ArraySliceValue<'ctx>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(size.get_type(), llvm_usize); assert_eq!(new_ndims.get_type(), llvm_usize); assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name( - generator, ctx, "__nac3_ndarray_reshape_resolve_and_check_new_shape", ); diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs index 57661fa7..6d152dd1 100644 --- a/nac3core/src/codegen/irrt/ndarray/transpose.rs +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -23,12 +23,12 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( dst_ndarray: NDArrayValue<'ctx>, axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index 6ee40e45..e2fd8c09 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -2,11 +2,10 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; use super::get_usize_dependent_function_name; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::CodeGenContext; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. -pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_string_eq<'ctx>( ctx: &CodeGenContext<'ctx, '_>, str1_ptr: PointerValue<'ctx>, str1_len: IntValue<'ctx>, @@ -15,7 +14,7 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq"); + let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { ctx.module.add_function( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 797a62be..f7483ef8 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1212,7 +1212,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let align_ty = align_ty.into(); let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6c16be9f..6700af45 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -207,7 +207,7 @@ pub fn gen_ndarray_eye<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); - let llvm_usize = generator.get_size_type(context.ctx); + let llvm_usize = context.get_size_type(); let llvm_dtype = context.get_llvm_type(generator, dtype); let nrows = context @@ -244,7 +244,7 @@ pub fn gen_ndarray_identity<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); - let llvm_usize = generator.get_size_type(context.ctx); + let llvm_usize = context.get_size_type(); let llvm_dtype = context.get_llvm_type(generator, dtype); let n = context @@ -325,8 +325,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); // Check shapes. - let a_size = a.size(generator, ctx); - let b_size = b.size(generator, ctx); + let a_size = a.size(ctx); + let b_size = b.size(ctx); let same_shape = ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( @@ -353,9 +353,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); Ok((a_iter, b_iter)) }, - |generator, ctx, (a_iter, _b_iter)| { + |_, ctx, (a_iter, _b_iter)| { // Only a_iter drives the condition, b_iter should have the same status. - Ok(a_iter.has_element(generator, ctx)) + Ok(a_iter.has_element(ctx)) }, |_, ctx, _hooks, (a_iter, b_iter)| { let a_scalar = a_iter.get_scalar(ctx); @@ -385,9 +385,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - |generator, ctx, (a_iter, b_iter)| { - a_iter.next(generator, ctx); - b_iter.next(generator, ctx); + |_, ctx, (a_iter, b_iter)| { + a_iter.next(ctx); + b_iter.next(ctx); Ok(()) }, ) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 2a3bd066..c3274057 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -306,7 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { // Handle list item assignment - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; let target = generator @@ -367,10 +367,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, key_ty)? .into_int_value(); - let index = ctx - .builder - .build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext") - .unwrap(); + let index = + ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap(); // handle negative index let is_negative = ctx @@ -378,7 +376,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); @@ -460,7 +458,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let target = broadcast_result.ndarrays[0]; let value = broadcast_result.ndarrays[1]; - target.copy_data_from(generator, ctx, value); + target.copy_data_from(ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); @@ -484,7 +482,7 @@ pub fn gen_for( let var_assignment = ctx.var_assignment.clone(); let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero = int32.const_zero(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let body_bb = ctx.ctx.append_basic_block(current, "for.body"); diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 337d049c..9ea4acaa 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -152,7 +152,7 @@ impl<'ctx> ListType<'ctx> { _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), }; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { None } else { @@ -273,7 +273,7 @@ impl<'ctx> ListType<'ctx> { } let plist = self.alloca_var(generator, ctx, name); - plist.store_size(ctx, generator, len); + plist.store_size(ctx, len); let item = self.item.unwrap_or(self.llvm_usize.into()); plist.create_data(ctx, item, None); @@ -300,7 +300,7 @@ impl<'ctx> ListType<'ctx> { ) -> >::Value { let plist = self.alloca_var(generator, ctx, name); - plist.store_size(ctx, generator, self.llvm_usize.const_zero()); + plist.store_size(ctx, self.llvm_usize.const_zero()); plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); plist diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 0f30f0eb..b0c9d637 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -67,9 +67,7 @@ impl<'ctx> NDArrayType<'ctx> { unsafe { ndarray.create_data(generator, ctx) }; // Copy all contents from the list. - irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( - generator, ctx, list_value, ndarray, - ); + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray); ndarray } @@ -116,7 +114,7 @@ impl<'ctx> NDArrayType<'ctx> { } // Set strides, the `data` is contiguous - ndarray.set_strides_contiguous(generator, ctx); + ndarray.set_strides_contiguous(ctx); ndarray } else { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index e5fb8cdc..f4a8b73d 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -140,7 +140,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } } diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index bf82b4da..6fdd9e12 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -86,10 +86,10 @@ impl<'ctx> NDArrayType<'ctx> { .collect_vec(); Ok((nditer, other_nditers)) }, - |generator, ctx, (out_nditer, _in_nditers)| { + |_, ctx, (out_nditer, _in_nditers)| { // We can simply use `out_nditer`'s `has_element()`. // `in_nditers`' `has_element()`s should return the same value. - Ok(out_nditer.has_element(generator, ctx)) + Ok(out_nditer.has_element(ctx)) }, |generator, ctx, _hooks, (out_nditer, in_nditers)| { // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, @@ -104,10 +104,10 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) }, - |generator, ctx, (out_nditer, in_nditers)| { + |_, ctx, (out_nditer, in_nditers)| { // Advance all iterators - out_nditer.next(generator, ctx); - in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + out_nditer.next(ctx); + in_nditers.iter().for_each(|nditer| nditer.next(ctx)); Ok(()) }, )?; diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 353ace33..a7bcb7ef 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -158,7 +158,7 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let ndims = extract_ndims(&ctx.unifier, ndims); NDArrayType { @@ -259,9 +259,9 @@ impl<'ctx> NDArrayType<'ctx> { .builder .build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") .unwrap(); - ndarray.store_itemsize(ctx, generator, itemsize); + ndarray.store_itemsize(ctx, itemsize); - ndarray.store_ndims(ctx, generator, ndims); + ndarray.store_ndims(ctx, ndims); ndarray.create_shape(ctx, self.llvm_usize, ndims); ndarray.create_strides(ctx, self.llvm_usize, ndims); @@ -307,7 +307,7 @@ impl<'ctx> NDArrayType<'ctx> { let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); @@ -342,7 +342,7 @@ impl<'ctx> NDArrayType<'ctx> { let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index ccb63b4a..947f95ad 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -52,7 +52,7 @@ impl<'ctx> TupleType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ty: Type, ) -> Self { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Sanity check on object type. let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index b756f278..9f6652b3 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -418,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index c497f8f8..08d2b6b5 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -97,13 +97,8 @@ impl<'ctx> ListValue<'ctx> { } /// Stores the `size` of this `list` into this instance. - pub fn store_size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - size: IntValue<'ctx>, - ) { - debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { + debug_assert_eq!(size.get_type(), ctx.get_size_type()); self.len_field(ctx).set(ctx, self.value, size, self.name); } @@ -213,7 +208,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 1b99f464..b145746e 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -112,7 +112,7 @@ impl<'ctx> NDArrayValue<'ctx> { target_shape.base_ptr(ctx, generator), ); - irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); + irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray); broadcast_ndarray } } @@ -146,7 +146,7 @@ fn broadcast_shapes<'ctx, G, Shape>( Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); assert!(in_shape_entries @@ -199,7 +199,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> BroadcastAllResult<'ctx, G> { assert!(!ndarrays.is_empty()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Infer the broadcast output ndims. let broadcast_ndims_int = diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 8eb700b9..52082df6 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -130,7 +130,7 @@ impl<'ctx> NDArrayValue<'ctx> { gen_if_callback( generator, ctx, - |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); @@ -184,7 +184,7 @@ impl<'ctx> NDArrayValue<'ctx> { // Copy shape and update strides let shape = carray.load_shape(ctx); ndarray.copy_shape_from_array(generator, ctx, shape); - ndarray.set_strides_contiguous(generator, ctx); + ndarray.set_strides_contiguous(ctx); // Share data let data = carray.load_data(ctx); diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3821f232..1a96522b 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -245,7 +245,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) + SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type()) .alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index f802c0c0..a24316b4 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -35,7 +35,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); // Deduce ndims of the result of matmul. @@ -315,7 +315,7 @@ impl<'ctx> NDArrayValue<'ctx> { let result_shape = result.shape(); out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); - out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray.copy_data_from(ctx, result); out_ndarray } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index b32a8f63..595345e8 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -81,13 +81,8 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { + debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); let pndims = self.ptr_to_ndims(ctx); ctx.builder.build_store(pndims, ndims).unwrap(); @@ -104,13 +99,8 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the size of each element `itemsize` into this instance. - pub fn store_itemsize( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - itemsize: IntValue<'ctx>, - ) { - debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { + debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); } @@ -205,12 +195,12 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) { - let nbytes = self.nbytes(generator, ctx); + let nbytes = self.nbytes(ctx); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None); self.store_data(ctx, data); - self.set_strides_contiguous(generator, ctx); + self.set_strides_contiguous(ctx); } /// Returns a proxy object to the field storing the data of this `NDArray`. @@ -284,52 +274,32 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Get the `np.size()` of this ndarray. - pub fn size( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self) + pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_size(ctx, *self) } /// Get the `ndarray.nbytes` of this ndarray. - pub fn nbytes( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self) + pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self) } /// Get the `len()` of this ndarray. - pub fn len( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self) + pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_len(ctx, *self) } /// Check if this ndarray is C-contiguous. /// /// See NumPy's `flags["C_CONTIGUOUS"]`: - pub fn is_c_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self) + pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self) } /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// /// Update the ndarray's strides to make the ndarray contiguous. - pub fn set_strides_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) { - irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); + pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self); } /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and @@ -347,7 +317,7 @@ impl<'ctx> NDArrayValue<'ctx> { let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { clone.create_data(generator, ctx) }; - clone.copy_data_from(generator, ctx, *self); + clone.copy_data_from(ctx, *self); clone } @@ -357,14 +327,9 @@ impl<'ctx> NDArrayValue<'ctx> { /// do not matter. The copying order is determined by how their flattened views look. /// /// Panics if the `dtype`s of ndarrays are different. - pub fn copy_data_from( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - src: NDArrayValue<'ctx>, - ) { + pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) { assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); - irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); + irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self); } /// Fill the ndarray with a scalar. @@ -468,7 +433,7 @@ impl<'ctx> NDArrayValue<'ctx> { ) -> Option> { if self.is_unsized() { // NOTE: `np.size(self) == 0` here is never possible. - let zero = generator.get_size_type(ctx.ctx).const_zero(); + let zero = ctx.get_size_type().const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; Some(value) @@ -756,9 +721,9 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + _: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) + irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0) } } @@ -770,7 +735,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx); // Current implementation is transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -834,7 +799,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); + assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into()); let indices = TypedArrayLikeAdapter::from( indices.as_slice_value(ctx, generator), @@ -867,7 +832,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let indices_size = indices.size(ctx, generator); let nidx_leq_ndims = ctx diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 4b4e07a1..3784193d 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -53,20 +53,16 @@ impl<'ctx> NDIterValue<'ctx> { /// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is 0-sized, this always returns false. #[must_use] - pub fn has_element( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self) + pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_nditer_has_element(ctx, *self) } /// Go to the next element. If `has_element()` is false, then this has undefined behavior. /// /// If `ndarray` is unsized, this can only be called once. /// If `ndarray` is 0-sized, this can never be called. - pub fn next(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) { - irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); + pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_nditer_next(ctx, *self); } fn element_field( @@ -167,10 +163,10 @@ impl<'ctx> NDArrayValue<'ctx> { |generator, ctx| { Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) }, - |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), - |generator, ctx, nditer| { - nditer.next(generator, ctx); + |_, ctx, nditer| { + nditer.next(ctx); Ok(()) }, ) diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index 190a1e4f..3ac2795d 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -30,7 +30,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), ) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let zero = llvm_usize.const_zero(); let one = llvm_usize.const_int(1, false); diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 5027be58..f68931f7 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -70,7 +70,7 @@ impl<'ctx> NDArrayValue<'ctx> { dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); // Resolve negative indices - let size = self.size(generator, ctx); + let size = self.size(ctx); let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_shape = dst_ndarray.shape(); irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( @@ -84,10 +84,10 @@ impl<'ctx> NDArrayValue<'ctx> { gen_if_callback( generator, ctx, - |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| Ok(self.is_c_contiguous(ctx)), |generator, ctx| { // Reshape is possible without copying - dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.set_strides_contiguous(ctx); dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); Ok(()) @@ -97,7 +97,7 @@ impl<'ctx> NDArrayValue<'ctx> { unsafe { dst_ndarray.create_data(generator, ctx); } - dst_ndarray.copy_data_from(generator, ctx, *self); + dst_ndarray.copy_data_from(ctx, *self); Ok(()) }, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 600276b7..a0673a11 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1278,11 +1278,7 @@ impl<'a> BuiltinBuilder<'a> { let size = ctx .builder - .build_int_truncate_or_bit_cast( - ndarray.size(generator, ctx), - ctx.ctx.i32_type(), - "", - ) + .build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "") .unwrap(); Ok(Some(size.into())) }), From 8e614d83de169a5dbd4a36219a3ab37a09c236e7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 14 Jan 2025 18:20:09 +0800 Subject: [PATCH 52/80] [core] codegen: Add ProxyType::new overloads and refactor to use them --- nac3artiq/src/codegen.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 130 ++++++++++-------- nac3core/src/codegen/expr.rs | 42 ++---- nac3core/src/codegen/mod.rs | 6 +- nac3core/src/codegen/numpy.rs | 16 +-- nac3core/src/codegen/stmt.rs | 3 +- nac3core/src/codegen/test.rs | 4 +- nac3core/src/codegen/types/list.rs | 45 ++++-- nac3core/src/codegen/types/ndarray/array.rs | 11 +- .../src/codegen/types/ndarray/broadcast.rs | 20 ++- .../src/codegen/types/ndarray/contiguous.rs | 22 ++- .../src/codegen/types/ndarray/indexing.rs | 17 ++- nac3core/src/codegen/types/ndarray/map.rs | 14 +- nac3core/src/codegen/types/ndarray/mod.rs | 80 +++++++---- nac3core/src/codegen/types/ndarray/nditer.rs | 20 ++- nac3core/src/codegen/types/tuple.rs | 27 +++- nac3core/src/codegen/types/utils/slice.rs | 24 +++- nac3core/src/codegen/values/list.rs | 8 +- .../src/codegen/values/ndarray/broadcast.rs | 4 +- .../src/codegen/values/ndarray/contiguous.rs | 11 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 22 +-- nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/ndarray/view.rs | 2 +- 25 files changed, 320 insertions(+), 228 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index cb75606d..c968198b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -476,8 +476,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) - .map_value(arg.into_pointer_value(), None); + let ndarray = + NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -609,7 +609,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) + let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 96f8c700..911e3dc1 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -752,24 +752,20 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1015,24 +1011,20 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1652,7 +1644,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1686,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1715,8 +1707,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let q = q.as_base_value().as_basic_value_enum(); let r = r.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) - .construct_from_objects(ctx, [q, r], None); + let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( + ctx, + [q, r], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1746,8 +1741,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray1_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1775,7 +1770,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let u = u.as_base_value().as_basic_value_enum(); let s = s.as_base_value().as_basic_value_enum(); let vh = vh.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) .construct_from_objects(ctx, [u, s, vh], None); Ok(tuple.as_base_value().into()) } @@ -1796,7 +1791,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,8 +1833,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) - .construct_dyn_shape(generator, ctx, &[d0, d1], None); + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2).construct_dyn_shape( + generator, + ctx, + &[d0, d1], + None, + ); unsafe { out.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -1880,7 +1879,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1901,8 +1900,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let l = l.as_base_value().as_basic_value_enum(); let u = u.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) - .construct_from_objects(ctx, [l, u], None); + let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( + ctx, + [l, u], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1936,11 +1938,11 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) }; - let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into()) + let x2 = NDArrayType::new_unsized(ctx, ctx.ctx.f64_type().into()) .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,8 +1981,12 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) - .construct_const_shape(generator, ctx, &[1], None); + let det = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1).construct_const_shape( + generator, + ctx, + &[1], + None, + ); unsafe { det.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -2014,7 +2020,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2037,8 +2043,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let t = t.as_base_value().as_basic_value_enum(); let z = z.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) - .construct_from_objects(ctx, [t, z], None); + let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( + ctx, + [t, z], + None, + ); Ok(tuple.as_base_value().into()) } @@ -2059,7 +2068,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); @@ -2082,7 +2091,10 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let h = h.as_base_value().as_basic_value_enum(); let q = q.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) - .construct_from_objects(ctx, [h, q], None); + let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( + ctx, + [h, q], + None, + ); Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 00290d34..8f52e929 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1167,7 +1167,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( + list = ListType::new(ctx, &elem_ty).construct( generator, ctx, list_alloc_size.into_int_value(), @@ -1218,12 +1218,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( - generator, - ctx, - length, - Some("listcomp"), - ); + list = ListType::new(ctx, &elem_ty).construct(generator, ctx, length, Some("listcomp")); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1386,8 +1381,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) - .construct(generator, ctx, size, None); + let new_list = + ListType::new(ctx, &llvm_elem_ty).construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1474,7 +1469,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( + let new_list = ListType::new(ctx, &elem_llvm_ty).construct( generator, ctx, ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), @@ -1576,8 +1571,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let right = right.to_ndarray(generator, ctx); let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, llvm_common_dtype, &[left.get_type(), right.get_type()], ) @@ -1850,8 +1844,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .to_ndarray(generator, ctx); let result_ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ctx.ctx.i8_type().into(), &[left.get_type(), right.get_type()], ) @@ -2480,18 +2473,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let length = ctx.get_size_type().const_int(elements.len() as u64, false); let arr_str_ptr = if let Some(ty) = ty { - ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("list"), - ) + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("list")) } else { - ListType::new_untyped(generator, ctx.ctx).construct_empty( - generator, - ctx, - Some("list"), - ) + ListType::new_untyped(ctx).construct_empty(generator, ctx, Some("list")) }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { @@ -2970,12 +2954,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("ret"), - ); + let res_array_ret = + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("ret")); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f7483ef8..dcfa2b8c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -530,7 +530,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -540,7 +540,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( @@ -594,7 +594,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6700af45..3cdd1ef3 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,7 +120,7 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full( generator, context, &shape, @@ -223,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -251,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -349,8 +349,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ctx, Some("np_dot"), |generator, ctx| { - let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); - let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + let a_iter = NDIterType::new(ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(ctx).construct(generator, ctx, b); Ok((a_iter, b_iter)) }, |_, ctx, (a_iter, _b_iter)| { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index c3274057..85a894ac 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -448,8 +448,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( - generator, - ctx.ctx, + ctx, value.get_type().element_type(), broadcast_ndims, ) diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 6518d858..a58a9847 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -446,7 +446,7 @@ fn test_classes_list_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into()); + let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); } @@ -466,6 +466,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); + let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 9ea4acaa..637cced3 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -104,7 +104,7 @@ impl<'ctx> ListType<'ctx> { element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - let element_type = element_type.unwrap_or(llvm_usize.into()); + let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum()); let field_tys = Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); @@ -112,26 +112,45 @@ impl<'ctx> ListType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl( + ctx: &'ctx Context, + element_type: Option>, + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + + Self { ty: llvm_list, item: element_type, llvm_usize } + } + /// Creates an instance of [`ListType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, element_type: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - - Self { ty: llvm_list, item: Some(element_type), llvm_usize } + Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx)) } /// Creates an instance of [`ListType`] with an unknown element type. #[must_use] - pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, None, ctx.get_size_type()) + } - Self { ty: llvm_list, item: None, llvm_usize } + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, None, generator.get_size_type(ctx)) } /// Creates an [`ListType`] from a [unifier type][Type]. @@ -159,11 +178,7 @@ impl<'ctx> ListType<'ctx> { Some(ctx.get_llvm_type(generator, elem_type)) }; - Self { - ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), - item: llvm_elem_type, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) } /// Creates an [`ListType`] from a [`PointerType`]. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index b0c9d637..70611127 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -44,7 +44,7 @@ impl<'ctx> NDArrayType<'ctx> { assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); - let list_value = list.as_i8_list(generator, ctx); + let list_value = list.as_i8_list(ctx); // Validate `list` has a consistent shape. // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. @@ -61,8 +61,8 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -96,8 +96,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, 1) - .construct_uninitialized(generator, ctx, name); + let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name); // Set data let data = ctx @@ -168,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) + NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 5ee28454..3a1fd8da 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -79,15 +79,27 @@ impl<'ctx> ShapeEntryType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`ShapeEntryType`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ty, llvm_usize } } + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index f4a8b73d..c751d573 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -117,17 +117,26 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); + + Self { ty: llvm_cndarray, item, llvm_usize } + } + /// Creates an instance of [`ContiguousNDArrayType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type()) + } + + /// Creates an instance of [`ContiguousNDArrayType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); - - Self { ty: llvm_cndarray, item, llvm_usize } + Self::new_impl(ctx, item, generator.get_size_type(ctx)) } /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. @@ -140,9 +149,8 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); - Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } + Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) } /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 644e173c..3e4e1362 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -75,14 +75,25 @@ impl<'ctx> NDIndexType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ndindex, llvm_usize } } + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 6fdd9e12..ae24458c 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -46,9 +46,8 @@ impl<'ctx> NDArrayType<'ctx> { let out_ndarray = match out { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. - let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) - .construct_uninitialized(generator, ctx, None); + let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims) + .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, ctx, @@ -70,7 +69,7 @@ impl<'ctx> NDArrayType<'ctx> { }; // Map element-wise and store results into `mapped_ndarray`. - let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray); gen_for_callback( generator, ctx, @@ -80,9 +79,7 @@ impl<'ctx> NDArrayType<'ctx> { let other_nditers = broadcast_result .ndarrays .iter() - .map(|ndarray| { - NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) - }) + .map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray)) .collect_vec(); Ok((nditer, other_nditers)) }, @@ -169,8 +166,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { // Promote all input to ndarrays and map through them. let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); let ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ret_dtype, &inputs.iter().map(NDArrayValue::get_type).collect_vec(), ) diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index a7bcb7ef..fe73307d 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -107,24 +107,56 @@ impl<'ctx> NDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDArrayType`]. - #[must_use] - pub fn new( - generator: &G, + fn new_impl( ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ndims: u64, + llvm_usize: IntType<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self { + Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + ) -> Self { + Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx)) + } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more /// `ndarray` operands. #[must_use] - pub fn new_broadcast( + pub fn new_broadcast( + ctx: &CodeGenContext<'ctx, '_>, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new_impl( + ctx.ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, @@ -132,20 +164,28 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) + Self::new_impl( + ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + generator.get_size_type(ctx), + ) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] - pub fn new_unsized( + pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self { + Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. + #[must_use] + pub fn new_unsized_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - - NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } + Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx)) } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -158,15 +198,9 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); let ndims = extract_ndims(&ctx.unifier, ndims); - NDArrayType { - ty: Self::llvm_type(ctx.ctx, llvm_usize), - dtype: llvm_dtype, - ndims, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) } /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. @@ -304,7 +338,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -339,7 +373,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -389,8 +423,8 @@ impl<'ctx> NDArrayType<'ctx> { .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type()) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name); ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap(); ndarray } diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c77e4571..1d83742f 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -86,15 +86,27 @@ impl<'ctx> NDIterType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDIter`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_nditer = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_nditer, llvm_usize } } + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 947f95ad..5c736528 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -32,17 +32,34 @@ impl<'ctx> TupleType<'ctx> { ctx.struct_type(tys, false) } + fn new_impl( + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + /// Creates an instance of [`TupleType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self { + Self::new_impl( + ctx.ctx, + &tys.iter().map(BasicType::as_basic_type_enum).collect_vec(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>], ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_tuple = Self::llvm_type(ctx, tys); - - Self { ty: llvm_tuple, llvm_usize } + Self::new_impl(ctx, tys, generator.get_size_type(ctx)) } /// Creates an [`TupleType`] from a [unifier type][Type]. diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index fa5a3474..0ef4d1b0 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -122,19 +122,31 @@ impl<'ctx> SliceType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. - #[must_use] - pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, int_ty); Self { ty: llvm_ty, int_ty, llvm_usize } } + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) + } + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. #[must_use] - pub fn new_usize(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - Self::new(ctx, llvm_usize, llvm_usize) + pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type()) + } + + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. + #[must_use] + pub fn new_usize_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) } /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 08d2b6b5..4ba5b6af 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -114,13 +114,9 @@ impl<'ctx> ListValue<'ctx> { /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. #[must_use] - pub fn as_i8_list( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> ListValue<'ctx> { + pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index b145746e..b5182a2b 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -104,7 +104,7 @@ impl<'ctx> NDArrayValue<'ctx> { assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims) .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, @@ -147,7 +147,7 @@ fn broadcast_shapes<'ctx, G, Shape>( + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { let llvm_usize = ctx.get_size_type(); - let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(ctx); assert!(in_shape_entries .iter() diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 52082df6..0fbb85f0 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -117,8 +117,8 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { - let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca_var(generator, ctx, self.name); + let result = + ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name); // Set ndims and shape. let ndims = self.llvm_usize.const_int(self.ndims, false); @@ -178,8 +178,11 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) - .construct_uninitialized(generator, ctx, carray.name); + let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized( + generator, + ctx, + carray.name, + ); // Copy shape and update strides let shape = carray.load_shape(ctx); diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 1a96522b..60c9c3b7 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -128,11 +128,10 @@ impl<'ctx> NDArrayValue<'ctx> { indices: &[RustNDIndex<'ctx>], ) -> Self { let dst_ndims = self.deduce_ndims_after_indexing_with(indices); - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); - let indices = - NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices); + let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices); irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); dst_ndarray @@ -245,8 +244,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type()) - .alloca_var(generator, ctx, None); + SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index a24316b4..f12d36c1 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -108,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) + let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 595345e8..705412e0 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -377,12 +377,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Create the strides tuple of this ndarray like @@ -411,12 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. @@ -998,10 +990,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ) -> NDArrayValue<'ctx> { match self { ScalarOrNDArray::NDArray(ndarray) => *ndarray, - ScalarOrNDArray::Scalar(scalar) => { - NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) - .construct_unsized(generator, ctx, scalar, None) - } + ScalarOrNDArray::Scalar(scalar) => NDArrayType::new_unsized(ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 3784193d..dd900d64 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -160,9 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator, ctx, Some("ndarray_foreach"), - |generator, ctx| { - Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) - }, + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |_, ctx, nditer| { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index f68931f7..9ab3d306 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -65,7 +65,7 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); From 762a2447c3f57e9a09c014d39fe2869499e0dbd9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 16:08:55 +0800 Subject: [PATCH 53/80] [core] codegen: Remove obsolete comments Comments regarding the need for `llvm.stack{save,restore}` is obsolete now that `NDIter::indices` is allocated at the beginning of the function. --- nac3core/src/codegen/types/ndarray/nditer.rs | 5 ----- nac3core/src/codegen/values/ndarray/nditer.rs | 4 ---- 2 files changed, 9 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 1d83742f..45b6bb0a 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,11 +163,6 @@ impl<'ctx> NDIterType<'ctx> { } /// Allocate an [`NDIter`] that iterates through the given `ndarray`. - /// - /// Note: This function allocates an array on the stack at the current builder location, which - /// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to - /// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the - /// [`NDIter`] is no longer needed. #[must_use] pub fn construct( &self, diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index dd900d64..86f370e5 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -137,10 +137,6 @@ impl<'ctx> NDArrayValue<'ctx> { /// /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// get properties of the current iteration (e.g., the current element, indices, etc.) - /// - /// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and - /// after invoking this function respectively. See [`NDIterType::construct`] for an explanation - /// on why this is suggested. pub fn foreach<'a, G, F>( &self, generator: &mut G, From 357970a793b837b2553cf44e5c9ee3d92104b1d2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:19:18 +0800 Subject: [PATCH 54/80] [core] codegen/stmt: Add build_{break,continue}_branch functions --- nac3core/src/codegen/expr.rs | 4 +--- nac3core/src/codegen/stmt.rs | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8f52e929..30b8dcd3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2122,9 +2122,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ) .unwrap(); - ctx.builder - .build_unconditional_branch(hooks.exit_bb) - .unwrap(); + hooks.build_break_branch(&ctx.builder); Ok(()) }, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 85a894ac..7b99bc26 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,6 +1,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, + builder::Builder, types::{BasicType, BasicTypeEnum}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, @@ -662,11 +663,25 @@ pub fn gen_for( #[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] pub struct BreakContinueHooks<'ctx> { /// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop. - pub exit_bb: BasicBlock<'ctx>, + exit_bb: BasicBlock<'ctx>, /// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration /// of the loop. - pub latch_bb: BasicBlock<'ctx>, + latch_bb: BasicBlock<'ctx>, +} + +impl<'ctx> BreakContinueHooks<'ctx> { + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the exit + /// [`BasicBlock`], as if by calling `break`. + pub fn build_break_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.exit_bb).unwrap(); + } + + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the latch + /// [`BasicBlock`], as if by calling `continue`. + pub fn build_continue_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.latch_bb).unwrap(); + } } /// Generates a C-style `for` construct using lambdas, similar to the following C code: From 18e8e5269fbbac609021b60bf95a6f1125166f3f Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:35:31 +0800 Subject: [PATCH 55/80] [core] codegen/values/ndarray: Add fold utilities Needed for np_{any,all}. --- nac3core/src/codegen/values/ndarray/fold.rs | 101 ++++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 1 + 2 files changed, 102 insertions(+) create mode 100644 nac3core/src/codegen/values/ndarray/fold.rs diff --git a/nac3core/src/codegen/values/ndarray/fold.rs b/nac3core/src/codegen/values/ndarray/fold.rs new file mode 100644 index 00000000..7c8aebd4 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/fold.rs @@ -0,0 +1,101 @@ +use inkwell::values::{BasicValue, BasicValueEnum}; + +use super::{NDArrayValue, NDIterValue, ScalarOrNDArray}; +use crate::codegen::{ + stmt::{gen_for_callback, BreakContinueHooks}, + types::ndarray::NDIterType, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the + /// final value. + /// + /// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance + /// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the + /// properties of the current iterated element. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + V, + NDIterValue<'ctx>, + ) -> Result, + { + let acc_ptr = + generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap(); + ctx.builder.build_store(acc_ptr, init).unwrap(); + + gen_for_callback( + generator, + ctx, + Some("ndarray_fold"), + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), + |generator, ctx, hooks, nditer| { + let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap(); + let acc = f(generator, ctx, hooks, acc, nditer)?; + ctx.builder.build_store(acc_ptr, acc).unwrap(); + Ok(()) + }, + |_, ctx, nditer| { + nditer.next(ctx); + Ok(()) + }, + )?; + + let acc = ctx.builder.build_load(acc_ptr, "").unwrap(); + Ok(V::try_from(acc).unwrap()) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// See [`NDArrayValue::fold`]. + /// + /// The primary differences between this function and `NDArrayValue::fold` are: + /// + /// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not + /// available if this instance represents a scalar value. + /// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will + /// be created if this instance represents a scalar value. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Option<&BreakContinueHooks<'ctx>>, + V, + BasicValueEnum<'ctx>, + ) -> Result, + { + match self { + ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v), + ScalarOrNDArray::NDArray(v) => { + v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| { + let elem = nditer.get_scalar(ctx); + f(generator, ctx, Some(&hooks), acc, elem) + }) + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 705412e0..1bf5db31 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -30,6 +30,7 @@ pub use nditer::*; mod broadcast; mod contiguous; +mod fold; mod indexing; mod map; mod matmul; From 1cfaa1a77952f12456440557ae1ad1a3493253ff Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:36:03 +0800 Subject: [PATCH 56/80] [core] toplevel: Implement np_{any,all} --- nac3core/src/toplevel/builtins.rs | 70 +++++++++++++++++++++++++-- nac3core/src/toplevel/helper.rs | 4 ++ nac3standalone/demo/interpret_demo.py | 2 + nac3standalone/demo/src/ndarray.py | 56 +++++++++++++++++++++ 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index a0673a11..e06366c5 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -6,7 +6,8 @@ use strum::IntoEnumIterator; use super::{ helper::{ - debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims, + make_exception_fields, PrimDef, PrimDefDetails, }, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, @@ -15,9 +16,12 @@ use crate::{ codegen::{ builtin_fns, numpy::*, - stmt::exn_constructor, + stmt::{exn_constructor, gen_if_callback}, types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, + values::{ + ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, + ProxyValue, RangeValue, + }, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim), + PrimDef::FunNpSin | PrimDef::FunNpCos | PrimDef::FunNpTan @@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> { ) } + fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]); + + let param_ty = &[(self.num_or_ndarray_ty.ty, "a")]; + let ret_ty = self.primitives.bool; + let var_map = &self.num_or_ndarray_var_map; + let codegen_callback: Box = + Box::new(move |ctx, _, fun, args, generator| { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i1_k0 = llvm_i1.const_zero(); + let llvm_i1_k1 = llvm_i1.const_all_ones(); + + let a_ty = fun.0.args[0].ty; + let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val)); + let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty); + + let (init, sc_val) = match prim { + PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1), + PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0), + _ => unreachable!(), + }; + + let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| { + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, acc, sc_val, "") + .unwrap()) + }, + |_, ctx| { + if let Some(hooks) = hooks { + hooks.build_break_branch(&ctx.builder); + } + Ok(()) + }, + |_, _| Ok(()), + )?; + + let is_truthy = + builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value(); + + Ok(match prim { + PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(), + PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(), + _ => unreachable!(), + }) + })?; + + Ok(Some(acc.as_basic_value_enum())) + }); + + create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback) + } + /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index de90a41b..72d3eaa6 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -111,6 +111,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunNpAny, + FunNpAll, // Linalg functions FunNpDot, @@ -305,6 +307,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunNpAny => fun("np_any", None), + PrimDef::FunNpAll => fun("np_all", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index fa91ed3c..180d24f0 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -232,6 +232,8 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter + module.np_any = np.any + module.np_all = np.all # SciPy Math functions module.sp_spec_erf = special.erf diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index b668860f..d077b82d 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_any(): + s0 = 0 + output_bool(np_any(s0)) + s1 = 1 + output_bool(np_any(s1)) + + x1 = np_identity(5) + y1 = np_any(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_any(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_any(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_any(x4) + output_ndarray_float_2(x4) + output_bool(y4) + +def test_ndarray_all(): + s0 = 0 + output_bool(np_all(s0)) + s1 = 1 + output_bool(np_all(s1)) + + x1 = np_identity(5) + y1 = np_all(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_all(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_all(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_all(x4) + output_ndarray_float_2(x4) + output_bool(y4) + def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1851,6 +1904,9 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() + test_ndarray_any() + test_ndarray_all() + test_ndarray_dot() test_ndarray_cholesky() test_ndarray_qr() From 933804e2707d3e961496d60615693144351f0049 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 15 Jan 2025 21:18:45 +0800 Subject: [PATCH 57/80] update dependencies --- Cargo.lock | 111 ++++++++++++++++++++++++++++------------------------- flake.lock | 6 +-- 2 files changed, 62 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 646700d4..c1d93528 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,11 +65,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] @@ -105,9 +106,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "block-buffer" @@ -126,9 +127,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.7" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -141,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -151,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -163,14 +164,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -472,7 +473,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -581,9 +582,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "llvm-sys" @@ -610,9 +611,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "memchr" @@ -678,7 +679,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", "trybuild", ] @@ -761,45 +762,45 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_macros", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ - "phf_shared 0.11.2", + "phf_shared 0.11.3", "rand", ] [[package]] name = "phf_macros" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -808,16 +809,16 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" dependencies = [ - "siphasher", + "siphasher 0.3.11", ] [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -873,9 +874,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -927,7 +928,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -940,7 +941,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1049,9 +1050,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ "bitflags", "errno", @@ -1110,14 +1111,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -1174,6 +1175,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "smallvec" version = "1.13.2" @@ -1226,7 +1233,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1242,9 +1249,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.94" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -1326,7 +1333,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1604,9 +1611,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.22" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] @@ -1638,5 +1645,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] diff --git a/flake.lock b/flake.lock index 7672c219..3e4af709 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1735834308, - "narHash": "sha256-dklw3AXr3OGO4/XT1Tu3Xz9n/we8GctZZ75ZWVqAVhk=", + "lastModified": 1736798957, + "narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "6df24922a1400241dae323af55f30e4318a6ca65", + "rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3", "type": "github" }, "original": { From c15062ab4c1ff391b327b91f580b541f37cdde56 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 15 Jan 2025 21:33:58 +0800 Subject: [PATCH 58/80] msys2: update --- nix/windows/msys2_packages.nix | 162 ++++++++++++++++----------------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/nix/windows/msys2_packages.nix b/nix/windows/msys2_packages.nix index 0ac1aa8f..0859244d 100644 --- a/nix/windows/msys2_packages.nix +++ b/nix/windows/msys2_packages.nix @@ -1,15 +1,15 @@ { pkgs } : [ (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0frb5k16bbxdf8g379d16vl3qrh7n9pydn83gpfxpvwf3qlvnzyl"; - name = "mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1gv6hbqvfgjzirpljql1shlchldmf5ww3rfsspg90pq1frnwavjl"; + name = "mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0wh5km0v8j50pqz9bxb4f0w7r8zhsvssrjvc94np53iq8wjagk86"; - name = "mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1wbkvrx14ahc04cgkydvlxwmsl8jfnqwhy9sy4kn4wkdzmlcp1ax"; + name = "mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -19,15 +19,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; - sha256 = "1g2bkhgf60dywccxw911ydyigf3m25yqfh81m5099swr7mjsmzyf"; - name = "mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; + sha256 = "0vn5xgx9jjg66f8r9ylm9220qdbjdkffykfl6nwj14zv9y7xh4nj"; + name = "mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; - sha256 = "0ll6ci6d3mc7g04q0xixjc209bh8r874dqbczgns69jsad3wg6mi"; - name = "mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; + sha256 = "0wbp5pmrr0rk4mx7d1frvqlk4a061zw31zscs57srmvl0wv3pi2a"; + name = "mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -55,69 +55,69 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1clrbm8dk893byj8s15pgcgqqijm2zkd10zgyakamd8m354kj9q4"; - name = "mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0fpsnfyf0bg39a4ygzga06sr4wv4jp1jnc8lk6sr3z0nim0nlhjn"; + name = "mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1iz2c9475h8p20ydpp0znbhyb62rlrk7wr7xl7cmwbam7wkwr8rn"; - name = "mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0whqs9nvfmgxj3c83px6dipcdw9zi858kgd8130201fy1mbnafp1"; + name = "mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1hidciwlakxrp4kyb0j2v6g4lv76nn834g6b88w1j94fk3qc765d"; - name = "mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0rmzri7h043i73jy3c2jcrg3hy40dr5s9n96kmxgaghfhvlpilps"; + name = "mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1m1yhjkgzlbk10sv966qk4yji009ga0lr25gpgj2w7mcd2wixcr3"; - name = "mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; + sha256 = "04cqlh35asvlh06nmhwnx9h0yrqk8zxd9lpzxmm1xh64kvm9maxn"; + name = "mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "08gxc7h2achckknn6fz3p6yi7gxxvbaday8fpm4j56c4sa04n0df"; - name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "05zsqgq8zwdcfacyqdxdjcf80447bgnrz71xv5cds0y135yziy7l"; + name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "0fxd1pb197ki0gzw6z8gmd6wgpd9d28js6cp5d31d55kw7d1vz13"; - name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "12fkxpk7rwy36snvvc7sdivx81pd4ckzh5ilyh7gl6ly4qayppp6"; + name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1a8pjyhrzpc2z3784xxwix4i7yrz03ygnsk1wv9k0yq8m8wi9nbw"; - name = "mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; + sha256 = "102bbv5acq1fvrfn8bp1x3503cb8hvcxmlpr86qsba4vm11l0wrw"; + name = "mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "140m312jx1sywqjkvfij69d268m4jpdmilq5bb8khkf0ayb16036"; - name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1sris0qczxk5px9xy85976hbmqrpg49ns7yyzd9p455ckf740cid"; + name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "017j4h511wg37bacym73f8g6s0jcfgzbzabzxpc6anr3gy4kkpbg"; - name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1r0m5xpsxdl00a2daj4p0wgl6037700pvw6p6zl91h1dr092r6pa"; + name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; - sha256 = "11f4i4ai2bzvq6f06vxk1ymv7056c9707vdw489f1i2bdrf0c0ii"; - name = "mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0j4a642fpnvqs79chhinc8r5q53q1wllmc1bzb01a4y7w9rqg4hw"; + name = "mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; - sha256 = "0nxs571vb4f1i5vp91134p5blns9ml2r25nx6kdlg0zhd5x85kvm"; - name = "mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; + sha256 = "0nrz9788grl50nkbhxswry143rrwpdnc6pk6f0k30kcp19qq6y2d"; + name = "mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -127,9 +127,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; - sha256 = "1mpn397qsdz3l2fav6ymwjlj96ialn9m8sldii3ymbcyhranl3xx"; - name = "mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; + sha256 = "1dppwwx3wrn0lzrlk2q7bpsainbidrpw1ndp1aasyv42xhxl1sn1"; + name = "mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -139,9 +139,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; - sha256 = "13nz49li39z1zgfx1q9jg4vrmyrmqb6qdq0nqshidaqc6zr16k3g"; - name = "mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; + sha256 = "1zg58qbfybyqzcj0dalb13l48f9jsras318h02rka65r7wi0pdcg"; + name = "mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -169,9 +169,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; - sha256 = "1q5nxhsk04gidz66ai5wgd4dr04lfyakkfja9p0r5hrgg4ppqqjg"; - name = "mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; + sha256 = "0c36lg63imzw8i6j1ard42v5wgzpc83phzk8lvifvm0djndq2bbj"; + name = "mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -193,9 +193,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; - sha256 = "1p7q47fin12vzyf126v1azbbpgpa0y6ighfh6mbfdb6zcyq74kbd"; - name = "mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; + sha256 = "0kd2f7yh90815kyldxvdy8c6jyxyw0wv4f7k3shwp98w874m0mxd"; + name = "mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -271,15 +271,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; - sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc"; - name = "mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; + sha256 = "0gdn1351knjwgsqgyaa3l55qs135k7dn6mlf04vzjxlc1895wx5z"; + name = "mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; - sha256 = "139f91r392c68hsajm0c81690pmzkywb0p4x8ms8ms53ncxnz6gz"; - name = "mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; + sha256 = "1xjjwgkqf2j97pcx0yd6j0lgmzgbgqjjf0s7j29mc03g89fhdhw0"; + name = "mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -289,9 +289,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; - sha256 = "1hlfj9g4s767s502sawwbcv4a0xd3ym3ip4jswmhq48wh5050iyb"; - name = "mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; + sha256 = "0f98pzrwsxil90n55hz2ym2x2rzrrjrmnj8i2203n189qbxbg2c9"; + name = "mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -331,32 +331,32 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; - sha256 = "1v15j2pzy9wj4n1rjngdi2hf8h0l9z4lri3xb86yvdv1xl2msj6h"; - name = "mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; + sha256 = "0lksgrmylvpr7yyjcc1szm30pnag7ixrj7vhdql1ryi4k9309v8s"; + name = "mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; - sha256 = "1pn1fbj74rx837s9z8gqs4b0cr7kqi5m1m2mi9ibjpw64m1aqwxv"; - name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0d3mm26hnw716n0ppzqhydxcgm4im081hiiy6l4zp267ad3kfg93"; + name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; - sha256 = "18p1zhf7h3k3phf3bl483jg3k7y9zq375z6ww75g62158ic9lfyc"; - name = "mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; + sha256 = "006f2s12jmk35rppkp20rlm7k4kknsnh5h4krqs2ry2rd6qqkk9h"; + name = "mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; - sha256 = "1kiy7ail04ias47xbbhl9vpsz02g0g3f29ncgx5gcks9vgqldp6m"; - name = "mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; + sha256 = "0sgkhax9cwmkkrfrir45l91h6pgg339gaw6147gsayf8h8ag4brg"; + name = "mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; - sha256 = "03l04kjmy5p9whaw0h619gdg7yw1gxbz8phifq4pzh3c1wlw7yfd"; - name = "mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; + sha256 = "12ivpaj967y4bi8396q3fpii4fy5aakidxpv16rkyg1b831k0h93"; + name = "mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; }) ] From 4bd5349381d4b6ed4b3bfe430ffe13649f24050b Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 10:49:13 +0800 Subject: [PATCH 59/80] [core] add attributes to class string --- nac3core/src/toplevel/helper.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72d3eaa6..eb72d37a 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,21 +379,29 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { + TopLevelDef::Class { + name, ancestors, fields, methods, attributes, type_vars, .. + } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); + let attributes_str = attributes + .iter() + .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) + .collect_vec(); + let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( - "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", + "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), + attributes_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) From febfd1241d64486b6496b4a1c92ea18aa743e4d7 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 11:06:14 +0800 Subject: [PATCH 60/80] [core] add module type --- nac3artiq/src/lib.rs | 1 + nac3core/src/codegen/expr.rs | 2 +- nac3core/src/toplevel/composer.rs | 3 ++- nac3core/src/toplevel/helper.rs | 7 +++++++ nac3core/src/toplevel/mod.rs | 12 ++++++++++++ nac3core/src/typecheck/type_inferencer/mod.rs | 3 ++- 6 files changed, 25 insertions(+), 3 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index d35e66d1..59d4dbed 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -713,6 +713,7 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } + TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), } } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 30b8dcd3..6d2057e1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -979,7 +979,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( TopLevelDef::Class { .. } => { return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) } - TopLevelDef::Variable { .. } => unreachable!(), + TopLevelDef::Variable { .. } | TopLevelDef::Module { .. } => unreachable!(), } } .or_else(|_: String| { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bd9a9214..b293fb4c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,7 +101,8 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } + | TopLevelDef::Module { name, .. } => name.to_string(), TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index eb72d37a..72502aa4 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,6 +379,13 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { + TopLevelDef::Module { name, attributes, .. } => { + let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec(); + format!( + "Module {{\nname: {:?},\nattributes{:?}\n}}", + name, method_str + ) + } TopLevelDef::Class { name, ancestors, fields, methods, attributes, type_vars, .. } => { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index cba2f5e7..88c007ec 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -92,6 +92,18 @@ pub struct FunInstance { #[derive(Debug, Clone)] pub enum TopLevelDef { + Module { + /// Name of the module + name: StrRef, + /// Module ID used for [`TypeEnum`] + module_id: DefinitionId, + /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module + attributes: HashMap, + /// Symbol resolver of the module defined the class. + resolver: Option>, + /// Definition location. + loc: Option, + }, Class { /// Name for error messages and symbols. name: StrRef, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 742fa197..a58045be 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2734,7 +2734,8 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } => (name, false), + TopLevelDef::Class { name, .. } + | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) From 7fac801936ff2cd5ef1ebdba73559eab2642a2cc Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 11:06:50 +0800 Subject: [PATCH 61/80] [artiq] add module primitive type --- nac3artiq/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 59d4dbed..4174fc8c 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -159,6 +159,7 @@ pub struct PrimitivePythonId { generic_alias: (u64, u64), virtual_id: u64, option: u64, + module: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -1097,6 +1098,7 @@ impl Nac3 { tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), + module: get_attr_id(types_mod, "ModuleType"), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); From f15a64cc1b3e949dab9551698b4ed9d1c68d19e7 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 12:05:11 +0800 Subject: [PATCH 62/80] [artiq] register modules --- nac3artiq/src/lib.rs | 25 +++++++++++++++++++++---- nac3core/src/toplevel/composer.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 4174fc8c..78f427e5 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -43,7 +43,7 @@ use nac3core::{ OptimizationLevel, }, nac3parser::{ - ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, + ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }, symbol_resolver::SymbolResolver, @@ -470,12 +470,14 @@ impl Nac3 { ]; add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); + // Stores a mapping from module id to attributes let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; + let module_name: String = py_module.getattr("__name__")?.extract()?; let helper = helper.clone(); let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { @@ -490,7 +492,7 @@ impl Nac3 { } else { class_obj = None; } - let (name_to_pyid, resolver) = + let (name_to_pyid, resolver, _, _) = module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = @@ -519,9 +521,10 @@ impl Nac3 { }))) as Arc; let name_to_pyid = Rc::new(name_to_pyid); + let module_location = ast::Location::new(1, 1, stmt.location.file); module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone())); - (name_to_pyid, resolver) + .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); + (name_to_pyid, resolver, module_name, Some(module_location)) }); let (name, def_id, ty) = composer @@ -595,6 +598,20 @@ impl Nac3 { } } + // Adding top level module definitions + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { + let def_id= composer.register_top_level_module( + module_name, + module_name_to_pyid, + module_resolver, + module_location + ).map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; + + self.pyid_to_def.write().insert(module_id, def_id); + } + let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; let mut name_to_pyid: HashMap = HashMap::new(); let module = PyModule::new(py, "tmp")?; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index b293fb4c..a4ca27f1 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -202,6 +202,35 @@ impl TopLevelComposer { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } + /// register top level modules + pub fn register_top_level_module( + &mut self, + module_name: String, + name_to_pyid: Rc>, + resolver: Arc, + location: Option + ) -> Result { + let mut attributes: HashMap = HashMap::new(); + for (name, _) in name_to_pyid.iter() { + if let Ok(def_id) = resolver.get_identifier_def(*name) { + // Avoid repeated attribute instances resulting from multiple imports of same module + if self.defined_names.contains(&format!("{module_name}.{name}")) { + attributes.insert(*name, def_id); + } + }; + } + let module_def = TopLevelDef::Module { + name: module_name.clone().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + attributes, + resolver: Some(resolver), + loc: location + }; + + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); + Ok(DefinitionId(self.definition_ast_list.len() - 1)) + } + /// register, just remember the names of top level classes/function /// and check duplicate class/method/function definition pub fn register_top_level( From ce40a46f8a44f3d095e236e7773790e097b6c048 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 10:54:07 +0800 Subject: [PATCH 63/80] [core] add module type --- nac3core/src/codegen/concrete_type.rs | 13 ++ nac3core/src/typecheck/type_inferencer/mod.rs | 137 ++++++++++-------- nac3core/src/typecheck/typedef/mod.rs | 33 ++++- 3 files changed, 118 insertions(+), 65 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index d5c1fc38..f0c92ed8 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -205,6 +205,19 @@ impl ConcreteTypeStore { }) .collect(), }, + TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule { + module_id: *module_id, + methods: attributes + .iter() + .filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) { + TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None, + _ => Some(( + *name, + (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1), + )), + }) + .collect(), + }, TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a58045be..7ce659f3 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2008,72 +2008,90 @@ impl Inferencer<'_> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { - // just a fast path - match (fields.get(&attr), ctx == ExprContext::Store) { - (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), - (Some((ty, false)), true) => report_type_error( - TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), - Some(value.location), - self.unifier, - ), - (None, mutable) => { - // Check whether it is a class attribute - let defs = self.top_level.definitions.read(); - let result = { - if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { - attributes.iter().find_map(|f| { - if f.0 == attr { - return Some(f.1); - } - None - }) - } else { - None - } - }; - match result { - Some(res) if !mutable => Ok(res), - Some(_) => report_error( - &format!("Class Attribute `{attr}` is immutable"), - value.location, - ), - None => report_type_error( - TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), - Some(value.location), - self.unifier, - ), - } - } - } - } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 - let result = { - self.top_level.definitions.read().iter().find_map(|def| { - if let Some(rear_guard) = def.try_read() { - if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { - if name.to_string() == self.unifier.stringify(sign.ret) { - return attributes.iter().find_map(|f| { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, fields, .. } => { + // just a fast path + match (fields.get(&attr), ctx == ExprContext::Store) { + (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, false)), true) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { if f.0 == attr { - return Some(f.clone().1); + return Some(f.1); } None - }); + }) + } else { + None } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), } } - None - }) - }; - match result { - Some(f) if ctx != ExprContext::Store => Ok(f), - Some(_) => { - report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) } - None => self.infer_general_attribute(value, attr, ctx), } - } else { - self.infer_general_attribute(value, attr, ctx) + TypeEnum::TFunc(sign) => { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => self.infer_general_attribute(value, attr, ctx), + } + } + TypeEnum::TModule { attributes, .. } => { + match (attributes.get(&attr), ctx == ExprContext::Load) { + (Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, true)), false) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, _) => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), + } + } + _ => self.infer_general_attribute(value, attr, ctx), } } @@ -2734,8 +2752,7 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } - | TopLevelDef::Module { name, .. } => (name, false), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e190c4c4..f2f9ed6f 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -270,6 +270,19 @@ pub enum TypeEnum { /// A function type. TFunc(FunSignature), + + /// Module Type + TModule { + /// The [`DefinitionId`] of this object type. + module_id: DefinitionId, + + /// The attributes present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). + attributes: Mapping, + }, } impl TypeEnum { @@ -284,6 +297,7 @@ impl TypeEnum { TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", + TypeEnum::TModule { .. } => "TModule", } } } @@ -593,7 +607,8 @@ impl Unifier { | TLiteral { .. } // functions are instantiated for each call sites, so the function type can contain // type variables. - | TFunc { .. } => true, + | TFunc { .. } + | TModule { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, @@ -1315,10 +1330,12 @@ impl Unifier { || format!("{id}"), |top_level| { let top_level_def = &top_level.definitions.read()[id]; - let TopLevelDef::Class { name, .. } = &*top_level_def.read() else { - unreachable!("expected class definition") + let top_level_def = top_level_def.read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + &*top_level_def + else { + unreachable!("expected module/class definition") }; - name.to_string() }, ) @@ -1446,6 +1463,10 @@ impl Unifier { let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); format!("fn[[{params}], {ret}]") } + TypeEnum::TModule { module_id, .. } => { + let name = obj_to_name(module_id.0); + name.to_string() + } } } @@ -1521,7 +1542,9 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => { + None + } TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TTuple { ty, is_vararg_ctx } => { let mut new_ty = Cow::from(ty); From 32f24261f280cfe36d75a08313697e540822655e Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 11:08:55 +0800 Subject: [PATCH 64/80] [artiq] add global variables to modules --- nac3artiq/src/lib.rs | 41 +++++--- nac3artiq/src/symbol_resolver.rs | 160 +++++++++++++++++++++++++++++- nac3core/src/toplevel/composer.rs | 56 +++++++---- nac3core/src/toplevel/helper.rs | 9 +- nac3core/src/toplevel/mod.rs | 6 +- 5 files changed, 234 insertions(+), 38 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 78f427e5..ba6c4fae 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -277,6 +277,10 @@ impl Nac3 { } }) } + // Allow global variable declaration with `Kernel` type annotation + StmtKind::AnnAssign { ref annotation, .. } => { + matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) + } _ => false, }; @@ -522,8 +526,15 @@ impl Nac3 { as Arc; let name_to_pyid = Rc::new(name_to_pyid); let module_location = ast::Location::new(1, 1, stmt.location.file); - module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); + module_to_resolver_cache.insert( + module_id, + ( + name_to_pyid.clone(), + resolver.clone(), + module_name.clone(), + Some(module_location), + ), + ); (name_to_pyid, resolver, module_name, Some(module_location)) }); @@ -599,15 +610,19 @@ impl Nac3 { } // Adding top level module definitions - for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { - let def_id= composer.register_top_level_module( - module_name, - module_name_to_pyid, - module_resolver, - module_location - ).map_err(|e| { - CompileError::new_err(format!("compilation failed\n----------\n{e}")) - })?; + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in + module_to_resolver_cache + { + let def_id = composer + .register_top_level_module( + &module_name, + &module_name_to_pyid, + module_resolver, + module_location, + ) + .map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; self.pyid_to_def.write().insert(module_id, def_id); } @@ -731,7 +746,9 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } - TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), + TopLevelDef::Module { .. } => { + unreachable!("Type module cannot be decorated with @rpc") + } } } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d9768669..4b398a9b 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -23,7 +23,7 @@ use nac3core::{ inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -674,6 +674,48 @@ impl InnerResolver { }) }); + // check if obj is module + if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)? + == self.primitive_ids.module + && self.pyid_to_def.read().contains_key(&py_obj_id) + { + let def_id = self.pyid_to_def.read()[&py_obj_id]; + let def = defs[def_id.0].read(); + let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } = + &*def + else { + unreachable!("must be a module here"); + }; + // Construct the module return type + let mut module_attributes = HashMap::new(); + for (name, _) in attributes { + let attribute_obj = obj.getattr(name.to_string().as_str())?; + let attribute_ty = + self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?; + if let Ok(attribute_ty) = attribute_ty { + module_attributes.insert(*name, (attribute_ty, false)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + for name in methods.keys() { + let method_obj = obj.getattr(name.to_string().as_str())?; + let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?; + if let Ok(method_ty) = method_ty { + module_attributes.insert(*name, (method_ty, true)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + let module_ty = + TypeEnum::TModule { module_id: *module_id, attributes: module_attributes }; + + let ty = unifier.add_ty(module_ty); + return Ok(Ok(ty)); + } + if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); return Ok(Ok(ty)); @@ -1373,6 +1415,77 @@ impl InnerResolver { None => Ok(None), } } + } else if ty_id == self.primitive_ids.module { + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let fields = { + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read(); + let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() }; + attributes + .iter() + .filter_map(|f| { + let definition = top_level_defs.get(f.1 .0).unwrap().read(); + if let TopLevelDef::Variable { ty, .. } = &*definition { + Some((f.0, *ty)) + } else { + None + } + }) + .collect_vec() + }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty)| { + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } } else { let id_str = id.to_string(); @@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx>( &self, id: StrRef, - _: &mut CodeGenContext<'ctx, '_>, - _: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { + if let Some(def_id) = self.0.id_to_def.read().get(&id) { + let top_levels = ctx.top_level.definitions.read(); + if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) { + let module_val = &self.0.module; + let ret = Python::with_gil(|py| -> PyResult> { + let module_val = module_val.as_ref(py); + + let ty = self.0.get_obj_type( + py, + module_val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; + if let Err(ty) = ty { + return Ok(Err(ty)); + } + let ty = ty.unwrap(); + let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap(); + let (idx, _) = ctx.get_attr_index(ty, id); + let ret = unsafe { + ctx.builder.build_gep( + obj.into_pointer_value(), + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(idx as u64, false), + ], + id.to_string().as_str(), + ) + } + .unwrap(); + Ok(Ok(ret.as_basic_value_enum())) + }) + .unwrap(); + if ret.is_err() { + return None; + } + return Some(ret.unwrap().into()); + } + } + let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a4ca27f1..a6a0ce76 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,8 +101,9 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } - | TopLevelDef::Module { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { + name.to_string() + } TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) @@ -205,29 +206,37 @@ impl TopLevelComposer { /// register top level modules pub fn register_top_level_module( &mut self, - module_name: String, - name_to_pyid: Rc>, + module_name: &str, + name_to_pyid: &Rc>, resolver: Arc, - location: Option + location: Option, ) -> Result { - let mut attributes: HashMap = HashMap::new(); + let mut methods: HashMap = HashMap::new(); + let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new(); + for (name, _) in name_to_pyid.iter() { if let Ok(def_id) = resolver.get_identifier_def(*name) { // Avoid repeated attribute instances resulting from multiple imports of same module if self.defined_names.contains(&format!("{module_name}.{name}")) { - attributes.insert(*name, def_id); + match &*self.definition_ast_list[def_id.0].0.read() { + TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => { + methods.insert(*name, def_id); + } + _ => attributes.push((*name, def_id)), + } } }; } - let module_def = TopLevelDef::Module { - name: module_name.clone().into(), - module_id: DefinitionId(self.definition_ast_list.len()), - attributes, - resolver: Some(resolver), - loc: location + let module_def = TopLevelDef::Module { + name: module_name.to_string().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + methods, + attributes, + resolver: Some(resolver), + loc: location, }; - self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None)); Ok(DefinitionId(self.definition_ast_list.len() - 1)) } @@ -499,10 +508,10 @@ impl TopLevelComposer { self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + self.analyze_top_level_variables()?; if inference { self.analyze_function_instance()?; } - self.analyze_top_level_variables()?; Ok(()) } @@ -1440,7 +1449,7 @@ impl TopLevelComposer { Ok(()) } - /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -1971,7 +1980,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 5. Analyze and populate the types of global variables. + /// Step 4. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -1989,6 +1998,19 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72502aa4..4ca5464f 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,11 +379,12 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Module { name, attributes, .. } => { - let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec(); + TopLevelDef::Module { name, attributes, methods, .. } => { format!( - "Module {{\nname: {:?},\nattributes{:?}\n}}", - name, method_str + "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}", + name, + attributes.iter().map(|(n, _)| n.to_string()).collect_vec(), + methods.iter().map(|(n, _)| n.to_string()).collect_vec() ) } TopLevelDef::Class { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 88c007ec..3ffd568a 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -97,8 +97,10 @@ pub enum TopLevelDef { name: StrRef, /// Module ID used for [`TypeEnum`] module_id: DefinitionId, - /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module - attributes: HashMap, + /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module + methods: HashMap, + /// `DefinitionId` of `TopLevelDef::{Variable}` within the module + attributes: Vec<(StrRef, DefinitionId)>, /// Symbol resolver of the module defined the class. resolver: Option>, /// Definition location. From 5fdbc34b430bd5875623eb5a0e0839d99422555d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 11:11:53 +0800 Subject: [PATCH 65/80] [core] implement codegen for modules --- nac3artiq/src/codegen.rs | 28 +++++++++++++++++++++++ nac3core/src/codegen/concrete_type.rs | 13 +++++++++++ nac3core/src/codegen/expr.rs | 26 +++++++++++++++++----- nac3core/src/codegen/mod.rs | 32 +++++++++++++++++++++++++++ nac3core/src/symbol_resolver.rs | 8 ++++--- 5 files changed, 98 insertions(+), 9 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index c968198b..e4727056 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>( )); } } + TypeEnum::TModule { attributes, .. } => { + let mut fields = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + for (name, (field_ty, is_method)) in attributes { + if *is_method { + continue; + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + fields.push(name.to_string()); + let (index, _) = ctx.get_attr_index(ty, *name); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); + } + } + if !fields.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", fields)?; + host_attributes.append(pydict)?; + } + } _ => {} } } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index f0c92ed8..503a4ae8 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum { fields: HashMap, params: IndexMap, }, + TModule { + module_id: DefinitionId, + methods: HashMap, + }, TVirtual { ty: ConcreteType, }, @@ -297,6 +301,15 @@ impl ConcreteTypeStore { TypeVar { id, ty } })), }, + ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule { + module_id: *module_id, + attributes: methods + .iter() + .map(|(name, cty)| { + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) + }) + .collect::>(), + }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args .iter() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 6d2057e1..fd2cd286 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -61,8 +61,13 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.clone() + } else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) { + indexmap::IndexMap::new() + } else { + unreachable!() + } }) .unwrap_or_default(); vars.extend(fun_vars); @@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, + TypeEnum::TModule { module_id, .. } => *module_id, // we cannot have other types, virtual type should be handled by function calls _ => codegen_unreachable!(self), }; @@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> { let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); (attribute_index.0, Some(attribute_index.1 .2.clone())) } + } else if let TopLevelDef::Module { attributes, .. } = &*def.read() { + (attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None) } else { codegen_unreachable!(self) }; @@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id + } else if let TypeEnum::TModule { module_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *module_id } else { codegen_unreachable!(ctx) }; @@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { + if let TopLevelDef::Class { methods, .. } = &*obj_def { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } else if let TopLevelDef::Module { methods, .. } = &*obj_def { + *methods.iter().find(|method| method.0 == attr).unwrap().1 + } else { codegen_unreachable!(ctx) - }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index dcfa2b8c..37e1bb33 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { + TModule {module_id, attributes} => { + let top_level_defs = top_level.definitions.read(); + let definition = top_level_defs.get(module_id.0).unwrap(); + let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else { + unreachable!() + }; + let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name.to_string()); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into(), + ); + let module_fields: Vec> = attribute_fields.iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + attributes[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&module_fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty; + }, TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes if PrimDef::contains_id(*obj_id) { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2378dd62..48290935 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { - unreachable!("expected class definition") + let top_level_def = &*top_level_defs[id].read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + top_level_def + else { + unreachable!("expected class/module definition") }; - name.to_string() }, &mut |id| format!("typevar{id}"), From 14e80dfab7dbd0522f5761c16a3fb30650c43951 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 12:41:30 +0800 Subject: [PATCH 66/80] update snapshots --- ...c3core__toplevel__test__test_analyze__generic_class.snap | 4 ++-- ..._toplevel__test__test_analyze__inheritance_override.snap | 6 +++--- ...e__toplevel__test__test_analyze__list_tuple_generic.snap | 4 ++-- .../nac3core__toplevel__test__test_analyze__self1.snap | 4 ++-- ..._toplevel__test__test_analyze__simple_class_compose.snap | 6 +++--- ..._toplevel__test__test_analyze__simple_pass_in_class.snap | 4 +--- 6 files changed, 13 insertions(+), 15 deletions(-) diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 4332b474..8c827eed 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 60e0c194..b8a80a5c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 46601817..05f44884 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", - "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index da58d121..7d3922e7 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 8f384fa1..b55e9985 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -3,14 +3,14 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap index 5178f1b4..2f37789c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap @@ -1,9 +1,7 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 549 expression: res_vec - --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n", ] From 879b063968235b2c14950baed75d2b539311d235 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 12:42:13 +0800 Subject: [PATCH 67/80] [artiq] add tests for module support --- nac3artiq/demo/module_support.py | 29 +++++++++++++++++++ nac3artiq/demo/tests/global_variables.py | 14 +++++++++ .../{ => tests}/string_attribute_issue337.py | 12 ++------ .../support_class_attr_issue102.py | 3 -- 4 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 nac3artiq/demo/module_support.py create mode 100644 nac3artiq/demo/tests/global_variables.py rename nac3artiq/demo/{ => tests}/string_attribute_issue337.py (57%) rename nac3artiq/demo/{ => tests}/support_class_attr_issue102.py (99%) diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py new file mode 100644 index 00000000..a863b380 --- /dev/null +++ b/nac3artiq/demo/module_support.py @@ -0,0 +1,29 @@ +from min_artiq import * +import tests.string_attribute_issue337 as issue337 +import tests.support_class_attr_issue102 as issue102 +import tests.global_variables as global_variables + +@nac3 +class TestModuleSupport: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + # Accessing classes + issue337.Demo().run() + obj = issue102.Demo() + obj.attr3 = 3 + + # Calling functions + global_variables.inc_X() + global_variables.display_X() + + # Updating global variables + global_variables.X = 9 + global_variables.display_X() + +if __name__ == "__main__": + TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/tests/global_variables.py b/nac3artiq/demo/tests/global_variables.py new file mode 100644 index 00000000..ac0e0cf0 --- /dev/null +++ b/nac3artiq/demo/tests/global_variables.py @@ -0,0 +1,14 @@ +from min_artiq import * +from numpy import int32 + +X: Kernel[int32] = 1 + +@rpc +def display_X(): + print_int32(X) + +@kernel +def inc_X(): + global X + X += 1 + diff --git a/nac3artiq/demo/string_attribute_issue337.py b/nac3artiq/demo/tests/string_attribute_issue337.py similarity index 57% rename from nac3artiq/demo/string_attribute_issue337.py rename to nac3artiq/demo/tests/string_attribute_issue337.py index 9749462a..c0b36ed6 100644 --- a/nac3artiq/demo/string_attribute_issue337.py +++ b/nac3artiq/demo/tests/string_attribute_issue337.py @@ -1,16 +1,13 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: - core: KernelInvariant[Core] - attr1: KernelInvariant[str] - attr2: KernelInvariant[int32] - + attr1: Kernel[str] + attr2: Kernel[int32] + @kernel def __init__(self): - self.core = Core() self.attr2 = 32 self.attr1 = "SAMPLE" @@ -19,6 +16,3 @@ class Demo: print_int32(self.attr2) self.attr1 - -if __name__ == "__main__": - Demo().run() diff --git a/nac3artiq/demo/support_class_attr_issue102.py b/nac3artiq/demo/tests/support_class_attr_issue102.py similarity index 99% rename from nac3artiq/demo/support_class_attr_issue102.py rename to nac3artiq/demo/tests/support_class_attr_issue102.py index 1b931444..0482e3f1 100644 --- a/nac3artiq/demo/support_class_attr_issue102.py +++ b/nac3artiq/demo/tests/support_class_attr_issue102.py @@ -1,7 +1,6 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: attr1: KernelInvariant[int32] = 2 @@ -12,7 +11,6 @@ class Demo: def __init__(self): self.attr3 = 8 - @nac3 class NAC3Devices: core: KernelInvariant[Core] @@ -35,6 +33,5 @@ class NAC3Devices: NAC3Devices.attr4 # Attributes accessible for classes without __init__ - if __name__ == "__main__": NAC3Devices().run() From 2783834cb1e2df6e41ed2689cbd1a2db63ee18d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Fri, 17 Jan 2025 12:45:51 +0800 Subject: [PATCH 68/80] nac3artiq/demo: merge EmbeddingMap into min_artiq --- nac3artiq/demo/embedding_map.py | 39 ------------------------------- nac3artiq/demo/min_artiq.py | 41 ++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 40 deletions(-) delete mode 100644 nac3artiq/demo/embedding_map.py diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py deleted file mode 100644 index a43af698..00000000 --- a/nac3artiq/demo/embedding_map.py +++ /dev/null @@ -1,39 +0,0 @@ -class EmbeddingMap: - def __init__(self): - self.object_inverse_map = {} - self.object_map = {} - self.string_map = {} - self.string_reverse_map = {} - self.function_map = {} - self.attributes_writeback = [] - - def store_function(self, key, fun): - self.function_map[key] = fun - return key - - def store_object(self, obj): - obj_id = id(obj) - if obj_id in self.object_inverse_map: - return self.object_inverse_map[obj_id] - key = len(self.object_map) + 1 - self.object_map[key] = obj - self.object_inverse_map[obj_id] = key - return key - - def store_str(self, s): - if s in self.string_reverse_map: - return self.string_reverse_map[s] - key = len(self.string_map) - self.string_map[key] = s - self.string_reverse_map[s] = key - return key - - def retrieve_function(self, key): - return self.function_map[key] - - def retrieve_object(self, key): - return self.object_map[key] - - def retrieve_str(self, key): - return self.string_map[key] - diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 62d32cc3..fef018b2 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -6,7 +6,6 @@ from typing import Generic, TypeVar from math import floor, ceil import nac3artiq -from embedding_map import EmbeddingMap __all__ = [ @@ -193,6 +192,46 @@ def print_int64(x: int64): raise NotImplementedError("syscall not simulated") +class EmbeddingMap: + def __init__(self): + self.object_inverse_map = {} + self.object_map = {} + self.string_map = {} + self.string_reverse_map = {} + self.function_map = {} + self.attributes_writeback = [] + + def store_function(self, key, fun): + self.function_map[key] = fun + return key + + def store_object(self, obj): + obj_id = id(obj) + if obj_id in self.object_inverse_map: + return self.object_inverse_map[obj_id] + key = len(self.object_map) + 1 + self.object_map[key] = obj + self.object_inverse_map[obj_id] = key + return key + + def store_str(self, s): + if s in self.string_reverse_map: + return self.string_reverse_map[s] + key = len(self.string_map) + self.string_map[key] = s + self.string_reverse_map[s] = key + return key + + def retrieve_function(self, key): + return self.function_map[key] + + def retrieve_object(self, key): + return self.object_map[key] + + def retrieve_str(self, key): + return self.string_map[key] + + @nac3 class Core: ref_period: KernelInvariant[float] From 2d275949b8fc274afe7d441f63b90db3556d1bfd Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 17 Jan 2025 13:04:16 +0800 Subject: [PATCH 69/80] move tests from artiq to standalone --- .../demo/tests/string_attribute_issue337.py | 18 --------- .../demo/tests/support_class_attr_issue102.py | 37 ------------------- nac3standalone/demo/src/class_attributes.py | 35 ++++++++++++++++++ 3 files changed, 35 insertions(+), 55 deletions(-) delete mode 100644 nac3artiq/demo/tests/string_attribute_issue337.py delete mode 100644 nac3artiq/demo/tests/support_class_attr_issue102.py create mode 100644 nac3standalone/demo/src/class_attributes.py diff --git a/nac3artiq/demo/tests/string_attribute_issue337.py b/nac3artiq/demo/tests/string_attribute_issue337.py deleted file mode 100644 index c0b36ed6..00000000 --- a/nac3artiq/demo/tests/string_attribute_issue337.py +++ /dev/null @@ -1,18 +0,0 @@ -from min_artiq import * -from numpy import int32 - -@nac3 -class Demo: - attr1: Kernel[str] - attr2: Kernel[int32] - - @kernel - def __init__(self): - self.attr2 = 32 - self.attr1 = "SAMPLE" - - @kernel - def run(self): - print_int32(self.attr2) - self.attr1 - diff --git a/nac3artiq/demo/tests/support_class_attr_issue102.py b/nac3artiq/demo/tests/support_class_attr_issue102.py deleted file mode 100644 index 0482e3f1..00000000 --- a/nac3artiq/demo/tests/support_class_attr_issue102.py +++ /dev/null @@ -1,37 +0,0 @@ -from min_artiq import * -from numpy import int32 - -@nac3 -class Demo: - attr1: KernelInvariant[int32] = 2 - attr2: int32 = 4 - attr3: Kernel[int32] - - @kernel - def __init__(self): - self.attr3 = 8 - -@nac3 -class NAC3Devices: - core: KernelInvariant[Core] - attr4: KernelInvariant[int32] = 16 - - def __init__(self): - self.core = Core() - - @kernel - def run(self): - Demo.attr1 # Supported - # Demo.attr2 # Field not accessible on Kernel - # Demo.attr3 # Only attributes can be accessed in this way - # Demo.attr1 = 2 # Attributes are immutable - - self.attr4 # Attributes can be accessed within class - - obj = Demo() - obj.attr1 # Attributes can be accessed by class objects - - NAC3Devices.attr4 # Attributes accessible for classes without __init__ - -if __name__ == "__main__": - NAC3Devices().run() diff --git a/nac3standalone/demo/src/class_attributes.py b/nac3standalone/demo/src/class_attributes.py new file mode 100644 index 00000000..b58958fa --- /dev/null +++ b/nac3standalone/demo/src/class_attributes.py @@ -0,0 +1,35 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_strln(x: str): + ... + + +class A: + a: int32 = 1 + b: int32 + c: str = "test" + d: str + + def __init__(self): + self.b = 2 + self.d = "test" + + output_int32(self.a) # Attributes can be accessed within class + + +def run() -> int32: + output_int32(A.a) # Attributes can be directly accessed with class name + # A.b # Only attributes can be accessed in this way + # A.a = 2 # Attributes are immutable + + obj = A() + output_int32(obj.a) # Attributes can be accessed by class objects + + output_strln(obj.c) + output_strln(obj.d) + + return 0 + From f817d3347bb4b385754141a50029ebf8e03a9912 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 17 Jan 2025 17:56:43 +0800 Subject: [PATCH 70/80] [artiq] cleanup module functionality tests --- nac3artiq/demo/module.py | 26 ++++++++++++++++++++++++ nac3artiq/demo/module_support.py | 17 +++++++--------- nac3artiq/demo/tests/global_variables.py | 14 ------------- 3 files changed, 33 insertions(+), 24 deletions(-) create mode 100644 nac3artiq/demo/module.py delete mode 100644 nac3artiq/demo/tests/global_variables.py diff --git a/nac3artiq/demo/module.py b/nac3artiq/demo/module.py new file mode 100644 index 00000000..58f92450 --- /dev/null +++ b/nac3artiq/demo/module.py @@ -0,0 +1,26 @@ +from min_artiq import * +from numpy import int32 + +# Global Variable Definition +X: Kernel[int32] = 1 + +# TopLevelFunction Defintion +@kernel +def display_X(): + print_int32(X) + +# TopLevel Class Definition +@nac3 +class A: + @kernel + def __init__(self): + self.set_x(1) + + @kernel + def set_x(self, new_val: int32): + global X + X = new_val + + @kernel + def get_X(self) -> int32: + return X diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py index a863b380..78ef6565 100644 --- a/nac3artiq/demo/module_support.py +++ b/nac3artiq/demo/module_support.py @@ -1,7 +1,5 @@ from min_artiq import * -import tests.string_attribute_issue337 as issue337 -import tests.support_class_attr_issue102 as issue102 -import tests.global_variables as global_variables +import module as module_definition @nac3 class TestModuleSupport: @@ -13,17 +11,16 @@ class TestModuleSupport: @kernel def run(self): # Accessing classes - issue337.Demo().run() - obj = issue102.Demo() - obj.attr3 = 3 + obj = module_definition.A() + obj.get_X() + obj.set_x(2) # Calling functions - global_variables.inc_X() - global_variables.display_X() + module_definition.display_X() # Updating global variables - global_variables.X = 9 - global_variables.display_X() + module_definition.X = 9 + module_definition.display_X() if __name__ == "__main__": TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/tests/global_variables.py b/nac3artiq/demo/tests/global_variables.py deleted file mode 100644 index ac0e0cf0..00000000 --- a/nac3artiq/demo/tests/global_variables.py +++ /dev/null @@ -1,14 +0,0 @@ -from min_artiq import * -from numpy import int32 - -X: Kernel[int32] = 1 - -@rpc -def display_X(): - print_int32(X) - -@kernel -def inc_X(): - global X - X += 1 - From 05fd1a519902b3a8c050ec031ca7496b32626339 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Jan 2025 22:09:03 +0800 Subject: [PATCH 71/80] [meta] Use lld as linker --- .cargo/config.toml | 2 ++ flake.nix | 5 +++-- nac3standalone/demo/run_demo.sh | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..188308db --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "link-arg=-fuse-ld=lld"] diff --git a/flake.nix b/flake.nix index a20e1f12..a48ff69b 100644 --- a/flake.nix +++ b/flake.nix @@ -41,7 +41,7 @@ lockFile = ./Cargo.lock; }; passthru.cargoLock = cargoLock; - nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; + nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.bintools llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkPhase = @@ -120,6 +120,7 @@ buildInputs = [ (python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ])) pkgs.llvmPackages_14.llvm.out + pkgs.llvmPackages_14.bintools ]; phases = [ "buildPhase" "installPhase" ]; buildPhase = @@ -168,7 +169,7 @@ buildInputs = with pkgs; [ # build dependencies packages.x86_64-linux.llvm-nac3 - (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos + (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out llvmPackages_14.bintools # for running nac3standalone demos packages.x86_64-linux.llvm-tools-irrt cargo rustc diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index bec2eb6e..78e32dd2 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -58,7 +58,7 @@ rm -f ./*.o ./*.bc demo if [ -z "$i686" ]; then $nac3standalone "${nac3args[@]}" clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c - clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch + clang -o demo module.o demo.o $DEMO_LINALG_STUB -fuse-ld=lld -lm else $nac3standalone --triple i686-unknown-linux-gnu --target-features +sse2 "${nac3args[@]}" clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c From 37df08b803c4c58b1d95eb735d9d1d4b97910e10 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 11:01:30 +0800 Subject: [PATCH 72/80] [meta] Update dependencies --- Cargo.lock | 168 ++++++++++++++++++++++--------------------- nac3artiq/Cargo.toml | 4 +- nac3ast/Cargo.toml | 2 +- nac3core/Cargo.toml | 7 +- 4 files changed, 94 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1d93528..9435feda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -9,7 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -127,9 +127,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.9" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "shlex", ] @@ -142,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -200,9 +200,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -334,9 +334,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "fxhash" @@ -374,7 +380,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] @@ -389,20 +407,14 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "foldhash", +] [[package]] name = "heck" @@ -437,9 +449,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -498,9 +510,9 @@ checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] @@ -522,9 +534,9 @@ dependencies = [ [[package]] name = "lalrpop" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f" +checksum = "7047a26de42016abf8f181b46b398aef0b77ad46711df41847f6ed869a2a1d5b" dependencies = [ "ascii-canvas", "bit-set", @@ -544,9 +556,9 @@ dependencies = [ [[package]] name = "lalrpop-util" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce" +checksum = "e8d05b3fe34b8bd562c338db725dfa9beb9451a48f65f129ccb9538b48d2c93b" dependencies = [ "regex-automata", "rustversion", @@ -656,7 +668,7 @@ name = "nac3core" version = "0.1.0" dependencies = [ "crossbeam", - "indexmap 2.7.0", + "indexmap 2.7.1", "indoc", "inkwell", "insta", @@ -664,7 +676,6 @@ dependencies = [ "nac3core_derive", "nac3parser", "parking_lot", - "rayon", "regex", "strum", "strum_macros", @@ -752,12 +763,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap 2.7.1", ] [[package]] @@ -980,27 +991,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", + "getrandom 0.2.15", ] [[package]] @@ -1050,9 +1041,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -1069,9 +1060,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "same-file" @@ -1090,9 +1081,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" @@ -1116,9 +1107,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" dependencies = [ "itoa", "memchr", @@ -1165,9 +1156,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "similar" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" @@ -1189,12 +1180,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "string-interner" -version = "0.17.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e" +checksum = "1a3275464d7a9f2d4cac57c89c2ef96a8524dba2864c8d6f82e3980baf136f9b" dependencies = [ - "cfg-if", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "serde", ] @@ -1272,13 +1262,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1363,7 +1353,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.7.1", "serde", "serde_spanned", "toml_datetime", @@ -1372,9 +1362,9 @@ dependencies = [ [[package]] name = "trybuild" -version = "1.0.101" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4" +checksum = "b812699e0c4f813b872b373a4471717d9eb550da14b311058a4d9cf4173cbca6" dependencies = [ "dissimilar", "glob", @@ -1446,9 +1436,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-width" @@ -1518,6 +1508,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "winapi-util" version = "0.1.9" @@ -1611,13 +1610,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "ad699df48212c6cc6eb4435f35500ac6fd3b9913324f938aea302022ce19d310" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/nac3artiq/Cargo.toml b/nac3artiq/Cargo.toml index 2da812ed..fa804659 100644 --- a/nac3artiq/Cargo.toml +++ b/nac3artiq/Cargo.toml @@ -9,10 +9,10 @@ name = "nac3artiq" crate-type = ["cdylib"] [dependencies] -itertools = "0.13" +itertools = "0.14" pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } parking_lot = "0.12" -tempfile = "3.13" +tempfile = "3.16" nac3core = { path = "../nac3core" } nac3ld = { path = "../nac3ld" } diff --git a/nac3ast/Cargo.toml b/nac3ast/Cargo.toml index dc2bd558..947be09a 100644 --- a/nac3ast/Cargo.toml +++ b/nac3ast/Cargo.toml @@ -11,5 +11,5 @@ fold = [] [dependencies] parking_lot = "0.12" -string-interner = "0.17" +string-interner = "0.18" fxhash = "0.2" diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 6521a334..7badcee9 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -10,11 +10,10 @@ derive = ["dep:nac3core_derive"] no-escape-analysis = [] [dependencies] -itertools = "0.13" +itertools = "0.14" crossbeam = "0.8" -indexmap = "2.6" +indexmap = "2.7" parking_lot = "0.12" -rayon = "1.10" nac3core_derive = { path = "nac3core_derive", optional = true } nac3parser = { path = "../nac3parser" } strum = "0.26" @@ -31,4 +30,4 @@ indoc = "2.0" insta = "=1.11.0" [build-dependencies] -regex = "1.10" +regex = "1.11" From bdeeced1223a29db30bd5a92f48e627e1c953b16 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:55:49 +0800 Subject: [PATCH 73/80] [core] codegen: Normalize RangeType factory functions Better matches factory functions of other ProxyTypes. --- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/test.rs | 3 ++- nac3core/src/codegen/types/range.rs | 38 +++++++++++++++++++++++------ 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 37e1bb33..73a28b7a 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -800,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new(context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_base_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index a58a9847..01672c55 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -453,8 +453,9 @@ fn test_classes_list_type_new() { #[test] fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); - let llvm_range = RangeType::new(&ctx); + let llvm_range = RangeType::new_with_generator(&generator, &ctx); assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); } diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index bdd4e79c..b92d7658 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -5,9 +5,12 @@ use inkwell::{ }; use super::ProxyType; -use crate::codegen::{ - values::{ProxyValue, RangeValue}, - {CodeGenContext, CodeGenerator}, +use crate::{ + codegen::{ + values::{ProxyValue, RangeValue}, + {CodeGenContext, CodeGenerator}, + }, + typecheck::typedef::{Type, TypeEnum}, }; /// Proxy type for a `range` type in LLVM. @@ -54,12 +57,33 @@ impl<'ctx> RangeType<'ctx> { llvm_i32.array_type(3).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new(ctx: &'ctx Context) -> Self { + fn new_impl(ctx: &'ctx Context) -> Self { let llvm_range = Self::llvm_type(ctx); - RangeType::from_type(llvm_range) + RangeType { ty: llvm_range } + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + RangeType::new_impl(ctx.ctx) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new_with_generator(_: &G, ctx: &'ctx Context) -> Self { + Self::new_impl(ctx) + } + + /// Creates an [`RangeType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new(ctx) } /// Creates an [`RangeType`] from a [`PointerType`]. From 87a637b448db0613c8cc8b9a3a1d9cb91e8fa1f4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:45:45 +0800 Subject: [PATCH 74/80] [core] codegen: Refactor Proxy{Type,Value} for StructProxy{Type,Value} --- nac3artiq/src/codegen.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 6 +- nac3core/src/codegen/expr.rs | 4 +- .../src/codegen/irrt/ndarray/broadcast.rs | 7 +- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/codegen/test.rs | 4 +- nac3core/src/codegen/types/list.rs | 65 ++--- nac3core/src/codegen/types/mod.rs | 18 +- .../src/codegen/types/ndarray/broadcast.rs | 53 ++-- .../src/codegen/types/ndarray/contiguous.rs | 69 +++--- .../src/codegen/types/ndarray/indexing.rs | 47 ++-- nac3core/src/codegen/types/ndarray/mod.rs | 49 ++-- nac3core/src/codegen/types/ndarray/nditer.rs | 49 ++-- nac3core/src/codegen/types/range.rs | 233 +++++++++--------- nac3core/src/codegen/types/tuple.rs | 22 +- nac3core/src/codegen/types/utils/slice.rs | 97 ++++---- nac3core/src/codegen/values/list.rs | 11 +- nac3core/src/codegen/values/mod.rs | 20 +- .../src/codegen/values/ndarray/broadcast.rs | 11 +- .../src/codegen/values/ndarray/contiguous.rs | 11 +- .../src/codegen/values/ndarray/indexing.rs | 11 +- nac3core/src/codegen/values/ndarray/mod.rs | 11 +- nac3core/src/codegen/values/ndarray/nditer.rs | 11 +- nac3core/src/codegen/values/range.rs | 23 +- nac3core/src/codegen/values/tuple.rs | 11 +- nac3core/src/codegen/values/utils/slice.rs | 11 +- nac3core/src/toplevel/builtins.rs | 6 +- 27 files changed, 342 insertions(+), 530 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index e4727056..4c86028f 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -19,9 +19,9 @@ use nac3core::{ llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, type_aligned_alloca, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -1431,7 +1431,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeValue::from_pointer_value(value.into_pointer_value(), None); + let val = RangeType::new(ctx).map_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 911e3dc1..6cacac45 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -11,10 +11,10 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - types::{ndarray::NDArrayType, ListType, TupleType}, + types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, values::{ ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, - ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; @@ -47,7 +47,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let range_ty = ctx.primitives.range; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range")); + let arg = RangeType::new(ctx).map_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index fd2cd286..4da0ef33 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType}, + types::{ndarray::NDArrayType, ListType, RangeType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -1151,7 +1151,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index fceba25f..a7d40a57 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -55,10 +55,9 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( let llvm_usize = ctx.get_size_type(); assert_eq!(num_shape_entries.get_type(), llvm_usize); - assert!(ShapeEntryType::is_type( - generator, - ctx.ctx, - shape_entries.base_ptr(ctx, generator).get_type() + assert!(ShapeEntryType::is_representable( + shape_entries.base_ptr(ctx, generator).get_type(), + llvm_usize, ) .is_ok()); assert_eq!(dst_ndims.get_type(), llvm_usize); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 7b99bc26..e8f1d906 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -17,10 +17,10 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ ndarray::{RustNDIndex, ScalarOrNDArray}, - ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, }, CodeGenContext, CodeGenerator, }; @@ -511,7 +511,7 @@ pub fn gen_for( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 01672c55..ecc0ba96 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -455,8 +455,10 @@ fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); + let llvm_usize = generator.get_size_type(&ctx); + let llvm_range = RangeType::new_with_generator(&generator, &ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_base_type(), llvm_usize).is_ok()); } #[test] diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 637cced3..60015b8c 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -56,34 +56,6 @@ impl<'ctx> ListStructFields<'ctx> { } impl<'ctx> ListType<'ctx> { - /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); - }; - - let fields = ListStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields( - fields, - llvm_ty, - "list", - &[(fields.items.name(), &|ty| { - if ty.is_pointer_type() { - Ok(()) - } else { - Err(format!("Expected T* for `list.items`, got {ty}")) - } - })], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { @@ -184,7 +156,7 @@ impl<'ctx> ListType<'ctx> { /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); let ctx = ptr_ty.get_context(); @@ -336,24 +308,39 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); + }; + + let fields = ListStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 0a31d6a5..5865d636 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -17,8 +17,7 @@ //! on the stack. use inkwell::{ - context::Context, - types::BasicType, + types::{BasicType, IntType}, values::{IntValue, PointerValue}, }; @@ -46,18 +45,15 @@ pub trait ProxyType<'ctx>: Into { /// The type of values represented by this type. type Value: ProxyValue<'ctx, Type = Self>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String>; - /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String>; + /// Checks whether the type represented by `ty` expresses the same type represented by this + /// [`ProxyType`]. + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String>; /// Returns the type that should be used in `alloca` IR statements. fn alloca_type(&self) -> impl BasicType<'ctx>; diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 3a1fd8da..af1a26fa 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -32,28 +32,6 @@ pub struct ShapeEntryStructFields<'ctx> { } impl<'ctx> ShapeEntryType<'ctx> { - /// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!( - "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" - )); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDArray", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -103,7 +81,7 @@ impl<'ctx> ShapeEntryType<'ctx> { /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -152,24 +130,33 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { type Base = PointerType<'ctx>; type Value = ShapeEntryValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index c751d573..1987ab6d 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -58,36 +58,6 @@ impl<'ctx> ContiguousNDArrayStructFields<'ctx> { } impl<'ctx> ContiguousNDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields( - fields, - llvm_ty, - "ContiguousNDArray", - &[(fields.data.name(), &|ty| { - if ty.is_pointer_type() { - Ok(()) - } else { - Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) - } - })], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -160,7 +130,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, item, llvm_usize } } @@ -222,24 +192,41 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "ContiguousNDArray", + &[(fields.data.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) + } + })], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 3e4e1362..8e15c903 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -35,25 +35,6 @@ pub struct NDIndexStructFields<'ctx> { } impl<'ctx> NDIndexType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndindex` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = NDIndexStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) - } - #[must_use] fn fields( ctx: impl AsContextRef<'ctx>, @@ -96,7 +77,7 @@ impl<'ctx> NDIndexType<'ctx> { #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -180,24 +161,30 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = NDIndexStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index fe73307d..1743fe2b 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -62,26 +62,6 @@ pub struct NDArrayStructFields<'ctx> { } impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDArray", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -211,7 +191,7 @@ impl<'ctx> NDArrayType<'ctx> { ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize } } @@ -450,24 +430,31 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 45b6bb0a..6246eef2 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -44,26 +44,6 @@ pub struct NDIterStructFields<'ctx> { } impl<'ctx> NDIterType<'ctx> { - /// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { - return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDIter", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> { @@ -110,7 +90,7 @@ impl<'ctx> NDIterType<'ctx> { /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -208,24 +188,31 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { + return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDIter", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index b92d7658..158152bf 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -17,12 +17,125 @@ use crate::{ #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct RangeType<'ctx> { ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, } impl<'ctx> RangeType<'ctx> { - /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. - pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> { - let llvm_range_ty = llvm_ty.get_element_type(); + /// Creates an LLVM type corresponding to the expected structure of a `Range`. + #[must_use] + fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { + // typedef int32_t Range[3]; + let llvm_i32 = ctx.i32_type(); + llvm_i32.array_type(3).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_range = Self::llvm_type(ctx); + + RangeType { ty: llvm_range, llvm_usize } + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`RangeType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new(ctx) + } + + /// Creates an [`RangeType`] from a [`PointerType`]. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + RangeType { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of all fields of this `range` type. + #[must_use] + pub fn value_type(&self) -> IntType<'ctx> { + self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type Base = PointerType<'ctx>; + type Value = RangeValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + let llvm_range_ty = ty.get_element_type(); let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); }; @@ -49,120 +162,6 @@ impl<'ctx> RangeType<'ctx> { Ok(()) } - /// Creates an LLVM type corresponding to the expected structure of a `Range`. - #[must_use] - fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { - // typedef int32_t Range[3]; - let llvm_i32 = ctx.i32_type(); - llvm_i32.array_type(3).ptr_type(AddressSpace::default()) - } - - fn new_impl(ctx: &'ctx Context) -> Self { - let llvm_range = Self::llvm_type(ctx); - - RangeType { ty: llvm_range } - } - - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { - RangeType::new_impl(ctx.ctx) - } - - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new_with_generator(_: &G, ctx: &'ctx Context) -> Self { - Self::new_impl(ctx) - } - - /// Creates an [`RangeType`] from a [unifier type][Type]. - #[must_use] - pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { - // Check unifier type - assert!( - matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) - ); - - Self::new(ctx) - } - - /// Creates an [`RangeType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty).is_ok()); - - RangeType { ty: ptr_ty } - } - - /// Returns the type of all fields of this `range` type. - #[must_use] - pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() - } - - /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. - /// - /// See [`ProxyType::raw_alloca`]. - #[must_use] - pub fn alloca( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value(self.raw_alloca(ctx, name), name) - } - - /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. - /// - /// See [`ProxyType::raw_alloca_var`]. - #[must_use] - pub fn alloca_var( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value( - self.raw_alloca_var(generator, ctx, name), - name, - ) - } - - /// Converts an existing value into a [`RangeValue`]. - #[must_use] - pub fn map_value( - &self, - value: <>::Value as ProxyValue<'ctx>>::Base, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value(value, name) - } -} - -impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { - type Base = PointerType<'ctx>; - type Value = RangeValue<'ctx>; - - fn is_type( - generator: &G, - ctx: &'ctx Context, - llvm_ty: impl BasicType<'ctx>, - ) -> Result<(), String> { - if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) - } else { - Err(format!("Expected pointer type, got {llvm_ty:?}")) - } - } - - fn is_representable( - _: &G, - _: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty) - } - fn alloca_type(&self) -> impl BasicType<'ctx> { self.as_base_type().get_element_type().into_struct_type() } diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 5c736528..d05b7f26 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -21,11 +21,6 @@ pub struct TupleType<'ctx> { } impl<'ctx> TupleType<'ctx> { - /// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not. - pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> { - Ok(()) - } - /// Creates an LLVM type corresponding to the expected structure of a tuple. #[must_use] fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { @@ -83,7 +78,7 @@ impl<'ctx> TupleType<'ctx> { /// Creates an [`TupleType`] from a [`StructType`]. #[must_use] pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(struct_ty).is_ok()); + debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok()); TupleType { ty: struct_ty, llvm_usize } } @@ -165,24 +160,19 @@ impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { type Base = StructType<'ctx>; type Value = TupleValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected struct type, got {llvm_ty:?}")) } } - fn is_representable( - _generator: &G, - _ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty) + fn has_same_repr(_: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + Ok(()) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index 0ef4d1b0..b7fafefa 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -61,50 +61,6 @@ impl<'ctx> SliceFields<'ctx> { } impl<'ctx> SliceType<'ctx> { - /// Checks whether `llvm_ty` represents a `slice` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let fields = SliceFields::new(ctx, llvm_usize); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}")); - }; - - check_struct_type_matches_fields( - fields, - llvm_ty, - "Slice", - &[ - (fields.start.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.start`, got {ty}")) - } - }), - (fields.stop.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.stop`, got {ty}")) - } - }), - (fields.step.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.step`, got {ty}")) - } - }), - ], - ) - } - // TODO: Move this into e.g. StructProxyType #[must_use] pub fn get_fields(&self) -> SliceFields<'ctx> { @@ -156,7 +112,7 @@ impl<'ctx> SliceType<'ctx> { int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, int_ty).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, int_ty).is_ok()); Self { ty: ptr_ty, int_ty, llvm_usize } } @@ -221,24 +177,55 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { type Base = PointerType<'ctx>; type Value = SliceValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let fields = SliceFields::new(ctx, llvm_usize); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields( + fields, + llvm_ty, + "Slice", + &[ + (fields.start.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.start`, got {ty}")) + } + }), + (fields.stop.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.stop`, got {ty}")) + } + }), + (fields.step.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.step`, got {ty}")) + } + }), + ], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 4ba5b6af..075f7f64 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -21,15 +21,6 @@ pub struct ListValue<'ctx> { } impl<'ctx> ListValue<'ctx> { - /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ListType::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -37,7 +28,7 @@ impl<'ctx> ListValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); ListValue { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index c789fe0f..dae10f31 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,7 +1,6 @@ -use inkwell::{context::Context, values::BasicValue}; +use inkwell::{types::IntType, values::BasicValue}; use super::types::ProxyType; -use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use range::*; @@ -24,21 +23,8 @@ pub trait ProxyValue<'ctx>: Into { type Type: ProxyType<'ctx, Value = Self>; /// Checks whether `value` can be represented by this [`ProxyValue`]. - fn is_instance( - generator: &G, - ctx: &'ctx Context, - value: impl BasicValue<'ctx>, - ) -> Result<(), String> { - Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type()) - } - - /// Checks whether `value` can be represented by this [`ProxyValue`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - value: Self::Base, - ) -> Result<(), String> { - Self::is_instance(generator, ctx, value.as_basic_value_enum()) + fn is_instance(value: impl BasicValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + Self::Type::is_representable(value.as_basic_value_enum().get_type(), llvm_usize) } /// Returns the [type][ProxyType] of this value. diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index b5182a2b..acbd2997 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -26,15 +26,6 @@ pub struct ShapeEntryValue<'ctx> { } impl<'ctx> ShapeEntryValue<'ctx> { - /// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -42,7 +33,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 0fbb85f0..65e80258 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -23,15 +23,6 @@ pub struct ContiguousNDArrayValue<'ctx> { } impl<'ctx> ContiguousNDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -40,7 +31,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, item: dtype, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 60c9c3b7..3b7b8f10 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -30,15 +30,6 @@ pub struct NDIndexValue<'ctx> { } impl<'ctx> NDIndexValue<'ctx> { - /// Checks whether `value` is an instance of `ndindex`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDIndexValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -46,7 +37,7 @@ impl<'ctx> NDIndexValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 1bf5db31..cba35ad2 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -49,15 +49,6 @@ pub struct NDArrayValue<'ctx> { } impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - NDArrayType::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -67,7 +58,7 @@ impl<'ctx> NDArrayValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 86f370e5..5479b929 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -23,15 +23,6 @@ pub struct NDIterValue<'ctx> { } impl<'ctx> NDIterValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -41,7 +32,7 @@ impl<'ctx> NDIterValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, parent, indices, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 7e9976a6..b1a5806a 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -1,4 +1,7 @@ -use inkwell::values::{BasicValueEnum, IntValue, PointerValue}; +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue}, +}; use super::ProxyValue; use crate::codegen::{types::RangeType, CodeGenContext}; @@ -7,21 +10,21 @@ use crate::codegen::{types::RangeType, CodeGenContext}; #[derive(Copy, Clone)] pub struct RangeValue<'ctx> { value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } impl<'ctx> RangeValue<'ctx> { - /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. - pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> { - RangeType::is_representable(value.get_type()) - } - /// Creates an [`RangeValue`] from a [`PointerValue`]. #[must_use] - pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { - debug_assert!(Self::is_representable(ptr).is_ok()); + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - RangeValue { value: ptr, name } + RangeValue { value: ptr, llvm_usize, name } } fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -138,7 +141,7 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type()) + RangeType::from_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 5167e479..4558f18c 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -14,15 +14,6 @@ pub struct TupleValue<'ctx> { } impl<'ctx> TupleValue<'ctx> { - /// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: StructValue<'ctx>, - _llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - TupleType::is_representable(value.get_type()) - } - /// Creates an [`TupleValue`] from a [`StructValue`]. #[must_use] pub fn from_struct_value( @@ -30,7 +21,7 @@ impl<'ctx> TupleValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(value, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(value, llvm_usize).is_ok()); Self { value, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index dffe6cef..df9e4de5 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -24,15 +24,6 @@ pub struct SliceValue<'ctx> { } impl<'ctx> SliceValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`SliceValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -41,7 +32,7 @@ impl<'ctx> SliceValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, int_ty, llvm_usize, name } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e06366c5..1c3b0854 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -17,10 +17,10 @@ use crate::{ builtin_fns, numpy::*, stmt::{exn_constructor, gen_if_callback}, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, - ProxyValue, RangeValue, + ProxyValue, }, }, symbol_resolver::SymbolValue, @@ -577,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeValue::from_pointer_value(zelf, Some("range")); + let zelf = RangeType::new(ctx).map_value(zelf, Some("range")); let mut start = None; let mut stop = None; From 96e98947cccdcdfdd6d69865fdcadb85d116fc3f Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:46:30 +0800 Subject: [PATCH 75/80] [core] codegen: Add StructProxy{Type,Value} --- nac3core/src/codegen/types/structure.rs | 44 +++++++++++++++++++++++- nac3core/src/codegen/values/mod.rs | 1 + nac3core/src/codegen/values/structure.rs | 24 +++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 nac3core/src/codegen/values/structure.rs diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 87781d11..0e35c812 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -2,13 +2,55 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, - types::{BasicTypeEnum, IntType, StructType}, + types::{BasicTypeEnum, IntType, PointerType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + AddressSpace, }; use itertools::Itertools; +use super::ProxyType; use crate::codegen::CodeGenContext; +/// A LLVM type that is used to represent a corresponding structure-like type in NAC3. +pub trait StructProxyType<'ctx>: ProxyType<'ctx, Base = PointerType<'ctx>> { + /// The concrete type of [`StructFields`]. + type StructFields: StructFields<'ctx>; + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][StructType]. + fn has_same_struct_repr( + llvm_ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_pointer_repr(llvm_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][PointerType]. + fn has_same_pointer_repr( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_repr(llvm_ty, llvm_usize) + } + + /// Returns the fields present in this [`StructProxyType`]. + #[must_use] + fn get_fields(&self) -> Self::StructFields; + + /// Returns the [`StructType`]. + #[must_use] + fn get_struct_type(&self) -> StructType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + /// Returns the [`PointerType`] representing this type. + #[must_use] + fn get_pointer_type(&self) -> PointerType<'ctx> { + self.as_base_type() + } +} + /// Trait indicating that the structure is a field-wise representation of an LLVM structure. /// /// # Usage diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index dae10f31..9a246356 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -10,6 +10,7 @@ mod array; mod list; pub mod ndarray; mod range; +pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/values/structure.rs b/nac3core/src/codegen/values/structure.rs new file mode 100644 index 00000000..dfe4543b --- /dev/null +++ b/nac3core/src/codegen/values/structure.rs @@ -0,0 +1,24 @@ +use inkwell::values::{BasicValueEnum, PointerValue, StructValue}; + +use super::ProxyValue; +use crate::codegen::{types::structure::StructProxyType, CodeGenContext}; + +/// An LLVM value that is used to represent a corresponding structure-like value in NAC3. +pub trait StructProxyValue<'ctx>: + ProxyValue<'ctx, Base = PointerValue<'ctx>, Type: StructProxyType<'ctx, Value = Self>> +{ + /// Returns this value as a [`StructValue`]. + #[must_use] + fn get_struct_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructValue<'ctx> { + ctx.builder + .build_load(self.get_pointer_value(ctx), "") + .map(BasicValueEnum::into_struct_value) + .unwrap() + } + + /// Returns this value as a [`PointerValue`]. + #[must_use] + fn get_pointer_value(&self, _: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.as_base_value() + } +} From b521bc0c821643e9d9bc501e0b89544843a6d076 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 10:10:23 +0800 Subject: [PATCH 76/80] [core] codegen: Add Proxy{Type,Value}::as_abi_{type,value} Needed for PtrToOrBasic{Type,Value}. --- nac3artiq/src/codegen.rs | 2 +- nac3artiq/src/symbol_resolver.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 130 +++++++++--------- nac3core/src/codegen/expr.rs | 18 +-- nac3core/src/codegen/irrt/ndarray/array.rs | 4 +- nac3core/src/codegen/irrt/ndarray/basic.rs | 33 +++-- .../src/codegen/irrt/ndarray/broadcast.rs | 2 +- nac3core/src/codegen/irrt/ndarray/indexing.rs | 4 +- nac3core/src/codegen/irrt/ndarray/iter.rs | 8 +- .../src/codegen/irrt/ndarray/transpose.rs | 4 +- nac3core/src/codegen/mod.rs | 8 +- nac3core/src/codegen/numpy.rs | 16 +-- nac3core/src/codegen/test.rs | 6 +- nac3core/src/codegen/types/list.rs | 7 +- nac3core/src/codegen/types/mod.rs | 12 +- nac3core/src/codegen/types/ndarray/array.rs | 10 +- .../src/codegen/types/ndarray/broadcast.rs | 7 +- .../src/codegen/types/ndarray/contiguous.rs | 7 +- .../src/codegen/types/ndarray/indexing.rs | 7 +- nac3core/src/codegen/types/ndarray/mod.rs | 7 +- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- nac3core/src/codegen/types/range.rs | 9 +- nac3core/src/codegen/types/tuple.rs | 5 + nac3core/src/codegen/types/utils/slice.rs | 7 +- nac3core/src/codegen/values/list.rs | 7 +- nac3core/src/codegen/values/mod.rs | 14 +- .../src/codegen/values/ndarray/broadcast.rs | 5 + .../src/codegen/values/ndarray/contiguous.rs | 13 +- .../src/codegen/values/ndarray/indexing.rs | 5 + nac3core/src/codegen/values/ndarray/mod.rs | 19 ++- nac3core/src/codegen/values/ndarray/nditer.rs | 9 +- nac3core/src/codegen/values/range.rs | 11 +- nac3core/src/codegen/values/tuple.rs | 5 + nac3core/src/codegen/values/utils/slice.rs | 5 + nac3core/src/toplevel/builtins.rs | 8 +- 35 files changed, 268 insertions(+), 159 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 4c86028f..2cc54387 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -761,7 +761,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - ndarray.as_base_value().into() + ndarray.as_abi_value(ctx).into() } _ => { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 4b398a9b..06a9400c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1146,7 +1146,7 @@ impl InnerResolver { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1316,7 +1316,7 @@ impl InnerResolver { }; let ndarray = llvm_ndarray - .as_base_type() + .as_abi_type() .get_element_type() .into_struct_type() .const_named_struct(&[ @@ -1328,7 +1328,7 @@ impl InnerResolver { ]); let ndarray_global = ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 6cacac45..20a89d0a 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, IntValue}, + values::{BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -137,7 +137,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -197,7 +197,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -273,7 +273,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -338,7 +338,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -402,7 +402,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -448,7 +448,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -485,7 +485,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -550,7 +550,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -600,7 +600,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -650,7 +650,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -767,7 +767,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1026,7 +1026,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1653,11 +1653,11 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_cholesky( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_qr` linalg function @@ -1699,20 +1699,20 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_qr( ctx, - x1_c.as_base_value().into(), - q_c.as_base_value().into(), - r_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), + r_c.as_abi_value(ctx).into(), None, ); - let q = q.as_base_value().as_basic_value_enum(); - let r = r.as_base_value().as_basic_value_enum(); + let q = q.as_abi_value(ctx); + let r = r.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( ctx, - [q, r], + [q.into(), r.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_svd` linalg function @@ -1760,19 +1760,19 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_svd( ctx, - x1_c.as_base_value().into(), - u_c.as_base_value().into(), - s_c.as_base_value().into(), - vh_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), + s_c.as_abi_value(ctx).into(), + vh_c.as_abi_value(ctx).into(), None, ); - let u = u.as_base_value().as_basic_value_enum(); - let s = s.as_base_value().as_basic_value_enum(); - let vh = vh.as_base_value().as_basic_value_enum(); + let u = u.as_abi_value(ctx); + let s = s.as_abi_value(ctx); + let vh = vh.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) - .construct_from_objects(ctx, [u, s, vh], None); - Ok(tuple.as_base_value().into()) + .construct_from_objects(ctx, [u.into(), s.into(), vh.into()], None); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_inv` linalg function @@ -1800,12 +1800,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_inv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_pinv` linalg function @@ -1845,12 +1845,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_pinv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_lu` linalg function @@ -1892,20 +1892,20 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let u_c = u.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_lu( ctx, - x1_c.as_base_value().into(), - l_c.as_base_value().into(), - u_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + l_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), None, ); - let l = l.as_base_value().as_basic_value_enum(); - let u = u.as_base_value().as_basic_value_enum(); + let l = l.as_abi_value(ctx); + let u = u.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( ctx, - [l, u], + [l.into(), u.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -1953,13 +1953,13 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_matrix_power( ctx, - x1_c.as_base_value().into(), - x2_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + x2_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_det` linalg function @@ -1993,8 +1993,8 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( let out_c = det.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_det( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); @@ -2035,20 +2035,20 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let z_c = z.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_schur( ctx, - x1_c.as_base_value().into(), - t_c.as_base_value().into(), - z_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + t_c.as_abi_value(ctx).into(), + z_c.as_abi_value(ctx).into(), None, ); - let t = t.as_base_value().as_basic_value_enum(); - let z = z.as_base_value().as_basic_value_enum(); + let t = t.as_abi_value(ctx); + let z = z.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( ctx, - [t, z], + [t.into(), z.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2083,18 +2083,18 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let q_c = q.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_hessenberg( ctx, - x1_c.as_base_value().into(), - h_c.as_base_value().into(), - q_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + h_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), None, ); - let h = h.as_base_value().as_basic_value_enum(); - let q = q.as_base_value().as_basic_value_enum(); + let h = h.as_abi_value(ctx); + let q = q.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( ctx, - [h, q], + [h.into(), q.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4da0ef33..7a1d42f4 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1307,7 +1307,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( emit_cont_bb(ctx, list); - Ok(Some(list.as_base_value().into())) + Ok(Some(list.as_abi_value(ctx).into())) } /// Generates LLVM IR for a binary operator expression using the [`Type`] and @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } _ => todo!("Operator not supported"), @@ -1601,7 +1601,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(result) }) .unwrap(); - Ok(Some(result.as_base_value().into())) + Ok(Some(result.as_abi_value(ctx).into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1796,7 +1796,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - mapped_ndarray.as_base_value().into() + mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() })) @@ -1883,7 +1883,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - return Ok(Some(result_ndarray.as_base_value().into())); + return Ok(Some(result_ndarray.as_abi_value(ctx).into())); } } @@ -2493,7 +2493,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ); ctx.builder.build_store(elem_ptr, *v).unwrap(); } - arr_str_ptr.as_base_value().into() + arr_str_ptr.as_abi_value(ctx).into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -2988,7 +2988,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.as_base_value().into() + res_array_ret.as_abi_value(ctx).into() } else { let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { @@ -3050,7 +3050,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .index(generator, ctx, &indices) .split_unsized(generator, ctx) .to_basic_value_enum(); - return Ok(Some(ValueEnum::Dynamic(result))); + return Ok(Some(result.into())); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs index 5e9c0f0b..63a2ab00 100644 --- a/nac3core/src/codegen/irrt/ndarray/array.rs +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -36,7 +36,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato ctx, &name, None, - &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + &[list.as_abi_value(ctx).into(), ndims.into(), shape.base_ptr(ctx, generator).into()], None, None, ); @@ -65,7 +65,7 @@ pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( ctx, &name, None, - &[list.as_base_value().into(), ndarray.as_base_value().into()], + &[list.as_abi_value(ctx).into(), ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index aa792b15..5f291c86 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -93,7 +93,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); @@ -101,7 +101,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("size"), None, ) @@ -118,7 +118,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); @@ -126,7 +126,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("nbytes"), None, ) @@ -143,7 +143,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); @@ -151,7 +151,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("len"), None, ) @@ -167,7 +167,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); @@ -175,7 +175,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("is_c_contiguous"), None, ) @@ -194,7 +194,7 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!(index.get_type(), llvm_usize); @@ -204,7 +204,10 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx, &name, Some(llvm_pi8.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())], + &[ + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), + (llvm_usize.into(), index.into()), + ], Some("pelement"), None, ) @@ -227,7 +230,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!( BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), @@ -241,7 +244,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized &name, Some(llvm_pi8.into()), &[ - (llvm_ndarray.into(), ndarray.as_base_value().into()), + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), @@ -258,7 +261,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); @@ -266,7 +269,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx, &name, None, - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], None, None, ); @@ -288,7 +291,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index a7d40a57..59b0e4cd 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -30,7 +30,7 @@ pub fn call_nac3_ndarray_broadcast_to<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index df5b27de..0d5d920e 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -25,8 +25,8 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( &[ indices.size(ctx, generator).into(), indices.base_ptr(ctx, generator).into(), - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index ad90178c..e4424df0 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -40,8 +40,8 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - (iter.get_type().as_base_type().into(), iter.as_base_value().into()), - (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()), + (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), + (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], None, @@ -63,7 +63,7 @@ pub fn call_nac3_nditer_has_element<'ctx>( ctx, &name, Some(ctx.ctx.bool_type().into()), - &[iter.as_base_value().into()], + &[iter.as_abi_value(ctx).into()], None, None, ) @@ -77,5 +77,5 @@ pub fn call_nac3_nditer_has_element<'ctx>( pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); - infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); + infer_and_call_function(ctx, &name, None, &[iter.as_abi_value(ctx).into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs index 6d152dd1..331611fa 100644 --- a/nac3core/src/codegen/irrt/ndarray/transpose.rs +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -34,8 +34,8 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { axes.base_ptr(ctx, generator) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 73a28b7a..a188d1c3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -562,7 +562,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_abi_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -572,7 +572,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_abi_type().into() } _ => unreachable!( @@ -626,7 +626,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_abi_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), @@ -800,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new_with_generator(generator, context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3cdd1ef3..2eec88da 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -44,7 +44,7 @@ pub fn gen_ndarray_empty<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.zeros`. @@ -69,7 +69,7 @@ pub fn gen_ndarray_zeros<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.ones`. @@ -94,7 +94,7 @@ pub fn gen_ndarray_ones<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.full`. @@ -127,7 +127,7 @@ pub fn gen_ndarray_full<'ctx>( fill_value_arg, None, ); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } pub fn gen_ndarray_array<'ctx>( @@ -166,7 +166,7 @@ pub fn gen_ndarray_array<'ctx>( .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) .atleast_nd(generator, context, ndims); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.eye`. @@ -225,7 +225,7 @@ pub fn gen_ndarray_eye<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.identity`. @@ -253,7 +253,7 @@ pub fn gen_ndarray_identity<'ctx>( .unwrap(); let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.copy`. @@ -274,7 +274,7 @@ pub fn gen_ndarray_copy<'ctx>( let this = NDArrayType::from_unifier_type(generator, context, this_ty) .map_value(this_arg.into_pointer_value(), None); let ndarray = this.make_copy(generator, context); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.fill`. diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ecc0ba96..15c4654a 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -447,7 +447,7 @@ fn test_classes_list_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); - assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); + assert!(ListType::is_representable(llvm_list.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -458,7 +458,7 @@ fn test_classes_range_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_range = RangeType::new_with_generator(&generator, &ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type(), llvm_usize).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -470,5 +470,5 @@ fn test_classes_ndarray_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); - assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); + assert!(NDArrayType::is_representable(llvm_ndarray.as_abi_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 60015b8c..f99ad5cb 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -305,6 +305,7 @@ impl<'ctx> ListType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; @@ -344,12 +345,16 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 5865d636..abeab5ba 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -38,8 +38,10 @@ pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a - /// [LLVM pointer type][PointerType] for any non-primitive types. + /// The ABI type of which values of this type possess. + type ABI: BasicType<'ctx>; + + /// The LLVM type of which values of this type possess. type Base: BasicType<'ctx>; /// The type of values represented by this type. @@ -118,4 +120,10 @@ pub trait ProxyType<'ctx>: Into { /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; + + /// Returns this proxy as its ABI type, i.e. the expected type representation if a value of this + /// [`ProxyType`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_type(&self) -> Self::ABI; } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 70611127..9630ec15 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -151,7 +151,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, |generator, ctx| { let ndarray = self.construct_numpy_array_from_list_copy_none_impl( @@ -160,7 +160,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() @@ -189,11 +189,11 @@ impl<'ctx> NDArrayType<'ctx> { |_generator, _ctx| Ok(copy), |generator, ctx| { let ndarray = ndarray.make_copy(generator, ctx); // Force copy - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, - |_generator, _ctx| { + |_generator, ctx| { // No need to copy. Return `ndarray` itself. - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index af1a26fa..40847ce2 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -127,6 +127,7 @@ impl<'ctx> ShapeEntryType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ShapeEntryValue<'ctx>; @@ -160,12 +161,16 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 1987ab6d..40311a57 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -189,6 +189,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; @@ -230,12 +231,16 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 8e15c903..ec214ceb 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -158,6 +158,7 @@ impl<'ctx> NDIndexType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; @@ -188,12 +189,16 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 1743fe2b..a79a1f30 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -427,6 +427,7 @@ impl<'ctx> NDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; @@ -458,12 +459,16 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 6246eef2..ba21a7ea 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -185,6 +185,7 @@ impl<'ctx> NDIterType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; @@ -216,12 +217,16 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index 158152bf..b6f15c70 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -72,7 +72,7 @@ impl<'ctx> RangeType<'ctx> { /// Returns the type of all fields of this `range` type. #[must_use] pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + self.as_abi_type().get_element_type().into_array_type().get_element_type().into_int_type() } /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. @@ -120,6 +120,7 @@ impl<'ctx> RangeType<'ctx> { } impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = RangeValue<'ctx>; @@ -163,12 +164,16 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index d05b7f26..29e93233 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -157,6 +157,7 @@ impl<'ctx> TupleType<'ctx> { } impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type ABI = StructType<'ctx>; type Base = StructType<'ctx>; type Value = TupleValue<'ctx>; @@ -182,6 +183,10 @@ impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for StructType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index b7fafefa..e482ed5b 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -174,6 +174,7 @@ impl<'ctx> SliceType<'ctx> { } impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = SliceValue<'ctx>; @@ -229,12 +230,16 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 075f7f64..8b2b6cb2 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -110,7 +110,7 @@ impl<'ctx> ListValue<'ctx> { let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( - ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_abi_type(), "").unwrap(), self.llvm_usize, self.name, ) @@ -118,6 +118,7 @@ impl<'ctx> ListValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ListType<'ctx>; @@ -128,6 +129,10 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 9a246356..90f327e0 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,6 +1,6 @@ use inkwell::{types::IntType, values::BasicValue}; -use super::types::ProxyType; +use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; @@ -16,8 +16,10 @@ pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the - /// [LLVM pointer type][PointerValue]. + /// The ABI type of LLVM values represented by this instance. + type ABI: BasicValue<'ctx>; + + /// The type of LLVM values represented by this instance. type Base: BasicValue<'ctx>; /// The type of this value. @@ -33,4 +35,10 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; + + /// Returns this proxy as its ABI value, i.e. the expected value representation if a value + /// represented by this [`ProxyValue`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> Self::ABI; } diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index acbd2997..883b4613 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -58,6 +58,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ShapeEntryType<'ctx>; @@ -68,6 +69,10 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 65e80258..a23be229 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_base_value(), value, self.name); + self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -49,7 +49,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_base_value(), value, self.name); + self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -61,7 +61,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.as_base_value(), value, self.name); + self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -70,6 +70,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ContiguousNDArrayType<'ctx>; @@ -84,6 +85,10 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -124,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); + let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3b7b8f10..00846713 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -68,6 +68,7 @@ impl<'ctx> NDIndexValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIndexType<'ctx>; @@ -78,6 +79,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index cba35ad2..e45fe85d 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); + self.shape_field(ctx).set(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); + self.strides_field(ctx).set(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name); + self.data_field(ctx).set(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -462,6 +462,7 @@ impl<'ctx> NDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDArrayType<'ctx>; @@ -477,6 +478,10 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -503,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.shape_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -601,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.strides_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -699,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.data_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -966,7 +971,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { match self { ScalarOrNDArray::Scalar(scalar) => scalar, - ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), + ScalarOrNDArray::NDArray(ndarray) => ndarray.value.into(), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 5479b929..3fdd0a8c 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).get(ctx, self.as_base_value(), self.name); + let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).get(ctx, self.as_base_value(), self.name) + self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. @@ -105,6 +105,7 @@ impl<'ctx> NDIterValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIterType<'ctx>; @@ -115,6 +116,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index b1a5806a..20bdba79 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -34,7 +34,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], var_name.as_str(), ) @@ -49,7 +49,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], var_name.as_str(), ) @@ -64,7 +64,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], var_name.as_str(), ) @@ -137,6 +137,7 @@ impl<'ctx> RangeValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = RangeType<'ctx>; @@ -147,6 +148,10 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 4558f18c..08b2b8be 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -57,6 +57,7 @@ impl<'ctx> TupleValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type ABI = StructValue<'ctx>; type Base = StructValue<'ctx>; type Type = TupleType<'ctx>; @@ -67,6 +68,10 @@ impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for StructValue<'ctx> { diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index df9e4de5..21453f4d 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -150,6 +150,7 @@ impl<'ctx> SliceValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = SliceType<'ctx>; @@ -160,6 +161,10 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1c3b0854..165f64a8 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -664,7 +664,7 @@ impl<'a> BuiltinBuilder<'a> { zelf.store_end(ctx, stop); zelf.store_step(ctx, step); - Ok(Some(zelf.as_base_value().into())) + Ok(Some(zelf.as_abi_value(ctx).into())) }, )))), loc: None, @@ -1320,7 +1320,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(result_tuple.as_base_value().into())) + Ok(Some(result_tuple.as_abi_value(ctx).into())) }), ) } @@ -1356,7 +1356,7 @@ impl<'a> BuiltinBuilder<'a> { .map_value(arg_val.into_pointer_value(), None); let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument - Ok(Some(ndarray.as_base_value().into())) + Ok(Some(ndarray.as_abi_value(ctx).into())) }), ), @@ -1410,7 +1410,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + Ok(Some(new_ndarray.as_abi_value(ctx).as_basic_value_enum())) }), ) } From eec62c3bbb5e63a8bb5c256bba4514fbd8542f58 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 10:53:14 +0800 Subject: [PATCH 77/80] [core] codegen: Refactor StructField getters and setters --- nac3core/src/codegen/types/structure.rs | 39 ++++++++++++++----- nac3core/src/codegen/values/list.rs | 6 +-- .../src/codegen/values/ndarray/broadcast.rs | 4 +- .../src/codegen/values/ndarray/contiguous.rs | 12 +++--- .../src/codegen/values/ndarray/indexing.rs | 8 ++-- nac3core/src/codegen/values/ndarray/mod.rs | 16 ++++---- nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/utils/slice.rs | 30 +++++++------- 8 files changed, 70 insertions(+), 49 deletions(-) diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 0e35c812..d2622b0e 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, types::{BasicTypeEnum, IntType, PointerType, StructType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + values::{AggregateValueEnum, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -203,17 +203,38 @@ where /// Gets the value of this field for a given `obj`. #[must_use] - pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value { - obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap() + pub fn extract_value(&self, ctx: &CodeGenContext<'ctx, '_>, obj: StructValue<'ctx>) -> Value { + Value::try_from( + ctx.builder + .build_extract_value( + obj, + self.index, + &format!("{}.{}", obj.get_name().to_str().unwrap(), self.name), + ) + .unwrap(), + ) + .unwrap() } /// Sets the value of this field for a given `obj`. - pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) { - obj.set_field_at_index(self.index, value); + #[must_use] + pub fn insert_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + obj: StructValue<'ctx>, + value: Value, + ) -> StructValue<'ctx> { + let obj_name = obj.get_name().to_str().unwrap(); + let new_obj_name = if obj_name.chars().all(char::is_numeric) { "" } else { obj_name }; + + ctx.builder + .build_insert_value(obj, value, self.index, new_obj_name) + .map(AggregateValueEnum::into_struct_value) + .unwrap() } - /// Gets the value of this field for a pointer-to-structure. - pub fn get( + /// Loads the value of this field for a pointer-to-structure. + pub fn load( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, @@ -229,8 +250,8 @@ where .unwrap() } - /// Sets the value of this field for a pointer-to-structure. - pub fn set( + /// Stores the value of this field for a pointer-to-structure. + pub fn store( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 8b2b6cb2..cdd1a416 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -45,7 +45,7 @@ impl<'ctx> ListValue<'ctx> { /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - self.items_field(ctx).set(ctx, self.value, data, self.name); + self.items_field(ctx).store(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -91,7 +91,7 @@ impl<'ctx> ListValue<'ctx> { pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { debug_assert_eq!(size.get_type(), ctx.get_size_type()); - self.len_field(ctx).set(ctx, self.value, size, self.name); + self.len_field(ctx).store(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. @@ -100,7 +100,7 @@ impl<'ctx> ListValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> IntValue<'ctx> { - self.len_field(ctx).get(ctx, self.value, name) + self.len_field(ctx).load(ctx, self.value, name) } /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 883b4613..e30bfae2 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -44,7 +44,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { /// Stores the number of dimensions into this value. pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.value, value, self.name); + self.ndims_field().store(ctx, self.value, value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -53,7 +53,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { /// Stores the shape into this value. pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.value, value, self.name); + self.shape_field().store(ctx, self.value, value, self.name); } } diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index a23be229..b8bf0afa 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.ndims_field().store(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -49,11 +49,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.shape_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field().get(ctx, self.value, self.name) + self.shape_field().load(ctx, self.value, self.name) } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -61,11 +61,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.data_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } } @@ -129,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); + let data = self.data_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 00846713..49fdfe17 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -47,11 +47,11 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.type_field().get(ctx, self.value, self.name) + self.type_field().load(ctx, self.value, self.name) } pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.type_field().set(ctx, self.value, value, self.name); + self.type_field().store(ctx, self.value, value, self.name); } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -59,11 +59,11 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.value, value, self.name); + self.data_field().store(ctx, self.value, value, self.name); } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index e45fe85d..38c87e0c 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -94,12 +94,12 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); - self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); + self.itemsize_field(ctx).store(ctx, self.value, itemsize, self.name); } /// Returns the size of each element of this `NDArray` as a value. pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.itemsize_field(ctx).get(ctx, self.value, self.name) + self.itemsize_field(ctx).load(ctx, self.value, self.name) } fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { @@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.value, dims, self.name); + self.shape_field(ctx).store(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.value, strides, self.name); + self.strides_field(ctx).store(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.value, data.into_pointer_value(), self.name); + self.data_field(ctx).store(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -508,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.shape_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( @@ -606,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.strides_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( @@ -704,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.data_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 3fdd0a8c..e4855743 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); + let p = self.element_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) + self.nth_field(ctx).load(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index 21453f4d..f7739357 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -42,7 +42,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_defined_field().get(ctx, self.value, self.name) + self.start_defined_field().load(ctx, self.value, self.name) } fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -50,22 +50,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_field().get(ctx, self.value, self.name) + self.start_field().load(ctx, self.value, self.name) } pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(start) => { - self.start_defined_field().set( + self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.start_field().set(ctx, self.value, start, self.name); + self.start_field().store(ctx, self.value, start, self.name); } - None => self.start_defined_field().set( + None => self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -79,7 +79,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_defined_field().get(ctx, self.value, self.name) + self.stop_defined_field().load(ctx, self.value, self.name) } fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -87,22 +87,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_field().get(ctx, self.value, self.name) + self.stop_field().load(ctx, self.value, self.name) } pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(stop) => { - self.stop_defined_field().set( + self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.stop_field().set(ctx, self.value, stop, self.name); + self.stop_field().store(ctx, self.value, stop, self.name); } - None => self.stop_defined_field().set( + None => self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -116,7 +116,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_defined_field().get(ctx, self.value, self.name) + self.step_defined_field().load(ctx, self.value, self.name) } fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -124,22 +124,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_field().get(ctx, self.value, self.name) + self.step_field().load(ctx, self.value, self.name) } pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(step) => { - self.step_defined_field().set( + self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.step_field().set(ctx, self.value, step, self.name); + self.step_field().store(ctx, self.value, step, self.name); } - None => self.step_defined_field().set( + None => self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), From 68da9b0ecff29a203bcd9539a204e05130e74420 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Jan 2025 23:23:05 +0800 Subject: [PATCH 78/80] [core] codegen: Implement StructProxy on existing proxies --- nac3artiq/src/codegen.rs | 8 +- nac3artiq/src/symbol_resolver.rs | 20 ++-- nac3core/src/codegen/builtin_fns.rs | 59 +++++++----- nac3core/src/codegen/expr.rs | 6 +- nac3core/src/codegen/numpy.rs | 10 +- nac3core/src/codegen/stmt.rs | 4 +- nac3core/src/codegen/types/list.rs | 54 ++++++++--- nac3core/src/codegen/types/ndarray/array.rs | 8 +- .../src/codegen/types/ndarray/broadcast.rs | 53 ++++++++--- .../src/codegen/types/ndarray/contiguous.rs | 57 +++++++++--- .../src/codegen/types/ndarray/indexing.rs | 49 +++++++--- nac3core/src/codegen/types/ndarray/mod.rs | 60 +++++++++--- nac3core/src/codegen/types/ndarray/nditer.rs | 59 +++++++++--- nac3core/src/codegen/types/range.rs | 35 ++++++- nac3core/src/codegen/types/tuple.rs | 37 ++++++-- nac3core/src/codegen/types/utils/slice.rs | 78 ++++++++++++---- nac3core/src/codegen/values/list.rs | 61 +++++++----- .../src/codegen/values/ndarray/broadcast.rs | 39 ++++++-- .../src/codegen/values/ndarray/contiguous.rs | 34 ++++++- .../src/codegen/values/ndarray/indexing.rs | 32 ++++++- nac3core/src/codegen/values/ndarray/mod.rs | 92 +++++++++++-------- nac3core/src/codegen/values/ndarray/nditer.rs | 52 ++++++++--- nac3core/src/codegen/values/ndarray/shape.rs | 4 +- nac3core/src/codegen/values/range.rs | 26 +++++- nac3core/src/codegen/values/tuple.rs | 22 ++++- nac3core/src/codegen/values/utils/slice.rs | 34 ++++++- nac3core/src/toplevel/builtins.rs | 10 +- 27 files changed, 729 insertions(+), 274 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 2cc54387..572acccf 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -476,8 +476,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = - NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); + let ndarray = NDArrayType::new(ctx, dtype, ndims) + .map_pointer_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -1383,7 +1383,7 @@ fn polymorphic_print<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) - .map_value(value.into_pointer_value(), None); + .map_pointer_value(value.into_pointer_value(), None); let num_0 = llvm_usize.const_zero(); @@ -1431,7 +1431,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeType::new(ctx).map_value(value.into_pointer_value(), None); + let val = RangeType::new(ctx).map_pointer_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 06a9400c..f14a8eed 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -16,7 +16,7 @@ use pyo3::{ use super::PrimitivePythonId; use nac3core::{ codegen::{ - types::{ndarray::NDArrayType, ProxyType}, + types::{ndarray::NDArrayType, structure::StructProxyType, ProxyType}, values::ndarray::make_contiguous_strides, CodeGenContext, CodeGenerator, }, @@ -1315,17 +1315,13 @@ impl InnerResolver { .unwrap() }; - let ndarray = llvm_ndarray - .as_abi_type() - .get_element_type() - .into_struct_type() - .const_named_struct(&[ - ndarray_itemsize.into(), - ndarray_ndims.into(), - ndarray_shape.into(), - ndarray_strides.into(), - ndarray_data.into(), - ]); + let ndarray = llvm_ndarray.get_struct_type().const_named_struct(&[ + ndarray_itemsize.into(), + ndarray_ndims.into(), + ndarray_shape.into(), + ndarray_strides.into(), + ndarray_data.into(), + ]); let ndarray_global = ctx.module.add_global( llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 20a89d0a..dfb90824 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -47,14 +47,14 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let range_ty = ctx.primitives.range; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeType::new(ctx).map_value(arg.into_pointer_value(), Some("range")); + let arg = RangeType::new(ctx).map_pointer_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { TypeEnum::TTuple { .. } => { let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_struct_value(), None); + .map_struct_value(arg.into_struct_value(), None); llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } @@ -62,7 +62,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_pointer_value(), None); + .map_pointer_value(arg.into_pointer_value(), None); ctx.builder .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") .unwrap() @@ -72,7 +72,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { let list = ListType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_pointer_value(), None); + .map_pointer_value(arg.into_pointer_value(), None); ctx.builder .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") .unwrap() @@ -126,7 +126,8 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -186,7 +187,8 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -262,7 +264,8 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -327,7 +330,8 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -391,7 +395,8 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -435,7 +440,8 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -474,7 +480,8 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -536,7 +543,8 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -587,7 +595,8 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -637,7 +646,8 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -858,7 +868,8 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, a_ty).map_pointer_value(n, None); let llvm_dtype = ndarray.get_type().element_type(); let zero = llvm_usize.const_zero(); @@ -1638,7 +1649,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1672,7 +1683,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1727,7 +1738,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1785,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1820,7 +1831,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1865,7 +1876,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1974,7 +1985,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -2013,7 +2024,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { @@ -2061,7 +2072,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 7a1d42f4..20d296e4 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1151,7 +1151,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero @@ -1767,7 +1767,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) - .map_value(val.into_pointer_value(), None); + .map_pointer_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -3043,7 +3043,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray_ty = value.custom.unwrap(); let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?; let result = ndarray diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 2eec88da..dfb1b4d8 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -272,7 +272,7 @@ pub fn gen_ndarray_copy<'ctx>( obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let this = NDArrayType::from_unifier_type(generator, context, this_ty) - .map_value(this_arg.into_pointer_value(), None); + .map_pointer_value(this_arg.into_pointer_value(), None); let ndarray = this.make_copy(generator, context); Ok(ndarray.as_abi_value(context)) } @@ -295,7 +295,7 @@ pub fn gen_ndarray_fill<'ctx>( let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let this = NDArrayType::from_unifier_type(generator, context, this_ty) - .map_value(this_arg.into_pointer_value(), None); + .map_pointer_value(this_arg.into_pointer_value(), None); this.fill(generator, context, value_arg); Ok(()) } @@ -316,8 +316,10 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = + NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(n1, None); + let b = + NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_pointer_value(n2, None); // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. assert_eq!(a.get_type().ndims(), 1); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index e8f1d906..0c1b931a 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -440,7 +440,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( // ``` let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) - .map_value(target.into_pointer_value(), None); + .map_pointer_value(target.into_pointer_value(), None); let target = target.index(generator, ctx, &key); let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) @@ -511,7 +511,7 @@ pub fn gen_for( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index f99ad5cb..b4110daa 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,7 +1,7 @@ use inkwell::{ - context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -13,8 +13,9 @@ use crate::{ codegen::{ types::structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, - values::{ListValue, ProxyValue}, + values::ListValue, CodeGenContext, CodeGenerator, }, typecheck::typedef::{iter_type_vars, Type, TypeEnum}, @@ -62,13 +63,6 @@ impl<'ctx> ListType<'ctx> { ListStructFields::new_typed(item, llvm_usize) } - /// See [`ListType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, _ctx: &impl AsContextRef<'ctx>) -> ListStructFields<'ctx> { - Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of a `List`. #[must_use] fn llvm_type( @@ -153,9 +147,15 @@ impl<'ctx> ListType<'ctx> { Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) } + /// Creates an [`ListType`] from a [`StructType`]. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); let ctx = ptr_ty.get_context(); @@ -295,9 +295,27 @@ impl<'ctx> ListType<'ctx> { /// Converts an existing value into a [`ListValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -357,6 +375,14 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ListType<'ctx> { + type StructFields = ListStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ListType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 9630ec15..633a0b48 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -167,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) + NDArrayType::new(ctx, dtype, ndims).map_pointer_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. @@ -200,7 +200,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - ndarray.get_type().map_value(ndarray_val, name) + ndarray.get_type().map_pointer_value(ndarray_val, name) } /// Create a new ndarray like @@ -222,7 +222,7 @@ impl<'ctx> NDArrayType<'ctx> { if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { let list = ListType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) } @@ -230,7 +230,7 @@ impl<'ctx> NDArrayType<'ctx> { if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) } diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 40847ce2..fa532b42 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -10,10 +10,10 @@ use nac3core_derive::StructFields; use crate::codegen::{ types::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }, - values::{ndarray::ShapeEntryValue, ProxyValue}, + values::ndarray::ShapeEntryValue, CodeGenContext, CodeGenerator, }; @@ -41,13 +41,6 @@ impl<'ctx> ShapeEntryType<'ctx> { ShapeEntryStructFields::new(ctx, llvm_usize) } - /// See [`ShapeEntryStructFields::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -78,9 +71,15 @@ impl<'ctx> ShapeEntryType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx)) } + /// Creates a [`ShapeEntryType`] from a [`StructType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -117,9 +116,27 @@ impl<'ctx> ShapeEntryType<'ctx> { /// Converts an existing value into a [`ShapeEntryValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -173,6 +190,14 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ShapeEntryType<'ctx> { + type StructFields = ShapeEntryStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ShapeEntryType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 40311a57..1857536a 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -13,10 +13,11 @@ use crate::{ types::{ structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{ndarray::ContiguousNDArrayValue, ProxyValue}, + values::ndarray::ContiguousNDArrayValue, CodeGenContext, CodeGenerator, }, toplevel::numpy::unpack_ndarray_var_tys, @@ -67,13 +68,6 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ContiguousNDArrayStructFields::new_typed(item, llvm_usize) } - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> ContiguousNDArrayStructFields<'ctx> { - Self::fields(self.item, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type( @@ -123,9 +117,19 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) } + /// Creates an [`ContiguousNDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), item, llvm_usize) + } + /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, @@ -174,9 +178,28 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { /// Converts an existing value into a [`ContiguousNDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.item, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ContiguousNDArrayValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -243,6 +266,14 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type StructFields = ContiguousNDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item, self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ContiguousNDArrayType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index ec214ceb..d00e0fb3 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -10,12 +10,12 @@ use nac3core_derive::StructFields; use crate::codegen::{ types::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }, values::{ ndarray::{NDIndexValue, RustNDIndex}, - ArrayLikeIndexer, ArraySliceValue, ProxyValue, + ArrayLikeIndexer, ArraySliceValue, }, CodeGenContext, CodeGenerator, }; @@ -43,11 +43,6 @@ impl<'ctx> NDIndexType<'ctx> { NDIndexStructFields::new(ctx, llvm_usize) } - #[must_use] - pub fn get_fields(&self) -> NDIndexStructFields<'ctx> { - Self::fields(self.ty.get_context(), self.llvm_usize) - } - #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { let field_tys = @@ -76,7 +71,12 @@ impl<'ctx> NDIndexType<'ctx> { } #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -148,9 +148,26 @@ impl<'ctx> NDIndexType<'ctx> { } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -201,6 +218,14 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDIndexType<'ctx> { + type StructFields = NDIndexStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDIndexType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index a79a1f30..28ea5276 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{BasicValue, IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValue, IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -9,12 +9,12 @@ use itertools::Itertools; use nac3core_derive::StructFields; use super::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }; use crate::{ codegen::{ - values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeMutator}, + values::{ndarray::NDArrayValue, TypedArrayLikeMutator}, {CodeGenContext, CodeGenerator}, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, @@ -71,13 +71,6 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayStructFields::new(ctx, llvm_usize) } - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDArrayStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -183,9 +176,20 @@ impl<'ctx> NDArrayType<'ctx> { Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) } + /// Creates an [`NDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), dtype, ndims, llvm_usize) + } + /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, ndims: u64, @@ -411,9 +415,29 @@ impl<'ctx> NDArrayType<'ctx> { /// Converts an existing value into a [`NDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`NDArrayValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -471,6 +495,14 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDArrayType<'ctx> { + type StructFields = NDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDArrayType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index ba21a7ea..aec1a6f5 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -11,7 +11,9 @@ use nac3core_derive::StructFields; use super::ProxyType; use crate::codegen::{ irrt, - types::structure::{check_struct_type_matches_fields, StructField, StructFields}, + types::structure::{ + check_struct_type_matches_fields, StructField, StructFields, StructProxyType, + }, values::{ ndarray::{NDArrayValue, NDIterValue}, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter, @@ -50,13 +52,6 @@ impl<'ctx> NDIterType<'ctx> { NDIterStructFields::new(ctx, llvm_usize) } - /// See [`NDIterType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDIter`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -87,9 +82,15 @@ impl<'ctx> NDIterType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx)) } + /// Creates an [`NDIterType`] from a [`StructType`] representing an `NDIter`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -159,7 +160,8 @@ impl<'ctx> NDIterType<'ctx> { let indices = TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); + let nditer = + self.map_pointer_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); @@ -167,9 +169,30 @@ impl<'ctx> NDIterType<'ctx> { } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + parent, + indices, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, parent: NDArrayValue<'ctx>, indices: ArraySliceValue<'ctx>, name: Option<&'ctx str>, @@ -229,6 +252,14 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDIterType<'ctx> { + type StructFields = NDIterStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDIterType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index b6f15c70..e8f6f4d5 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,13 +1,14 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{ArrayValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::{ codegen::{ - values::{ProxyValue, RangeValue}, + values::RangeValue, {CodeGenContext, CodeGenerator}, }, typecheck::typedef::{Type, TypeEnum}, @@ -61,9 +62,15 @@ impl<'ctx> RangeType<'ctx> { Self::new(ctx) } + /// Creates an [`RangeType`] from a [`ArrayType`]. + #[must_use] + pub fn from_array_type(arr_ty: ArrayType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(arr_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`RangeType`] from a [`PointerType`]. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); RangeType { ty: ptr_ty, llvm_usize } @@ -110,9 +117,27 @@ impl<'ctx> RangeType<'ctx> { /// Converts an existing value into a [`RangeValue`]. #[must_use] - pub fn map_value( + pub fn map_array_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: ArrayValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_array_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 29e93233..ea66feb4 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -1,16 +1,13 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, IntType, StructType}, - values::BasicValueEnum, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValueEnum, PointerValue, StructValue}, }; use itertools::Itertools; use super::ProxyType; use crate::{ - codegen::{ - values::{ProxyValue, TupleValue}, - CodeGenContext, CodeGenerator, - }, + codegen::{values::TupleValue, CodeGenContext, CodeGenerator}, typecheck::typedef::{Type, TypeEnum}, }; @@ -77,12 +74,18 @@ impl<'ctx> TupleType<'ctx> { /// Creates an [`TupleType`] from a [`StructType`]. #[must_use] - pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_struct_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok()); TupleType { ty: struct_ty, llvm_usize } } + /// Creates an [`TupleType`] from a [`PointerType`]. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize) + } + /// Returns the number of elements present in this [`TupleType`]. #[must_use] pub fn num_elements(&self) -> u32 { @@ -117,7 +120,10 @@ impl<'ctx> TupleType<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name) + self.map_struct_value( + Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), + name, + ) } /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of @@ -147,13 +153,24 @@ impl<'ctx> TupleType<'ctx> { /// Converts an existing value into a [`ListValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + value: StructValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_struct_value(value, self.llvm_usize, name) } + + /// Converts an existing value into a [`TupleValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(ctx, value, self.llvm_usize, name) + } } impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index e482ed5b..e43ac743 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,11 @@ use crate::codegen::{ types::{ structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{utils::SliceValue, ProxyValue}, + values::utils::SliceValue, CodeGenContext, CodeGenerator, }; @@ -27,7 +28,7 @@ pub struct SliceType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct SliceFields<'ctx> { +pub struct SliceStructFields<'ctx> { #[value_type(bool_type())] pub start_defined: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize)] @@ -42,14 +43,14 @@ pub struct SliceFields<'ctx> { pub step: StructField<'ctx, IntValue<'ctx>>, } -impl<'ctx> SliceFields<'ctx> { - /// Creates a new instance of [`SliceFields`] with a custom integer type for its range values. +impl<'ctx> SliceStructFields<'ctx> { + /// Creates a new instance of [`SliceStructFields`] with a custom integer type for its range values. #[must_use] pub fn new_sized(ctx: &impl AsContextRef<'ctx>, int_ty: IntType<'ctx>) -> Self { let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) }; let mut counter = FieldIndexCounter::default(); - SliceFields { + SliceStructFields { start_defined: StructField::create(&mut counter, "start_defined", ctx.bool_type()), start: StructField::create(&mut counter, "start", int_ty), stop_defined: StructField::create(&mut counter, "stop_defined", ctx.bool_type()), @@ -61,16 +62,10 @@ impl<'ctx> SliceFields<'ctx> { } impl<'ctx> SliceType<'ctx> { - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> SliceFields<'ctx> { - SliceFields::new_sized(&self.int_ty.get_context(), self.int_ty) - } - /// Creates an LLVM type corresponding to the expected structure of a `Slice`. #[must_use] fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> { - let field_tys = SliceFields::new_sized(&int_ty.get_context(), int_ty) + let field_tys = SliceStructFields::new_sized(&int_ty.get_context(), int_ty) .into_iter() .map(|field| field.1) .collect_vec(); @@ -90,6 +85,16 @@ impl<'ctx> SliceType<'ctx> { Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) } + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + int_ty: IntType<'ctx>, + ) -> Self { + Self::new_impl(ctx, int_ty, generator.get_size_type(ctx)) + } + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. #[must_use] pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { @@ -105,9 +110,19 @@ impl<'ctx> SliceType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) } + /// Creates an [`SliceType`] from a [`StructType`] representing a `slice`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), int_ty, llvm_usize) + } + /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>, @@ -157,11 +172,30 @@ impl<'ctx> SliceType<'ctx> { ) } + /// Converts an existing value into a [`SliceValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.int_ty, + self.llvm_usize, + name, + ) + } + /// Converts an existing value into a [`ContiguousNDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_pointer_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -192,7 +226,7 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { let ctx = ty.get_context(); - let fields = SliceFields::new(ctx, llvm_usize); + let fields = SliceStructFields::new(ctx, llvm_usize); let llvm_ty = ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { @@ -242,6 +276,14 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for SliceType<'ctx> { + type StructFields = SliceStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + SliceStructFields::new_sized(&self.ty.get_context(), self.int_ty) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: SliceType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index cdd1a416..453065b8 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -1,14 +1,18 @@ use inkwell::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::{structure::StructField, ListType, ProxyType}, + types::{ + structure::{StructField, StructProxyType}, + ListType, ProxyType, + }, {CodeGenContext, CodeGenerator}, }; @@ -21,6 +25,26 @@ pub struct ListValue<'ctx> { } impl<'ctx> ListValue<'ctx> { + /// Creates an [`ListValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -33,19 +57,13 @@ impl<'ctx> ListValue<'ctx> { ListValue { value: ptr, llvm_usize, name } } - fn items_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(&ctx.ctx).items - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.items_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn items_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().items } /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - self.items_field(ctx).store(ctx, self.value, data, self.name); + self.items_field().store(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -83,15 +101,15 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } - fn len_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(&ctx.ctx).len + fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().len } /// Stores the `size` of this `list` into this instance. pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { debug_assert_eq!(size.get_type(), ctx.get_size_type()); - self.len_field(ctx).store(ctx, self.value, size, self.name); + self.len_field().store(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. @@ -100,7 +118,7 @@ impl<'ctx> ListValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> IntValue<'ctx> { - self.len_field(ctx).load(ctx, self.value, name) + self.len_field().load(ctx, self.value, name) } /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. @@ -123,7 +141,7 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { type Type = ListType<'ctx>; fn get_type(&self) -> Self::Type { - ListType::from_type(self.as_base_value().get_type(), self.llvm_usize) + ListType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -135,6 +153,8 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ListValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ListValue<'ctx>) -> Self { value.as_base_value() @@ -159,12 +179,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() + self.0.items_field().load(ctx, self.0.value, self.0.name) } fn size( diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index e30bfae2..4935a364 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, }; use itertools::Itertools; @@ -8,12 +8,13 @@ use crate::codegen::{ irrt, types::{ ndarray::{NDArrayType, ShapeEntryType}, - structure::StructField, + structure::{StructField, StructProxyType}, ProxyType, }, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::NDArrayValue, structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -26,6 +27,26 @@ pub struct ShapeEntryValue<'ctx> { } impl<'ctx> ShapeEntryValue<'ctx> { + /// Creates an [`ShapeEntryValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -39,7 +60,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(self.value.get_type().get_context()).ndims + self.get_type().get_fields().ndims } /// Stores the number of dimensions into this value. @@ -48,7 +69,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(self.value.get_type().get_context()).shape + self.get_type().get_fields().shape } /// Stores the shape into this value. @@ -63,7 +84,7 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { type Type = ShapeEntryType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -75,6 +96,8 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ShapeEntryValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ShapeEntryValue<'ctx>) -> Self { value.as_base_value() @@ -163,7 +186,7 @@ fn broadcast_shapes<'ctx, G, Shape>( None, ) }; - let shape_entry = llvm_shape_ty.map_value(pshape_entry, None); + let shape_entry = llvm_shape_ty.map_pointer_value(pshape_entry, None); let in_ndims = llvm_usize.const_int(*in_ndims, false); shape_entry.store_ndims(ctx, in_ndims); diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index b8bf0afa..9dca06a4 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -1,16 +1,17 @@ use inkwell::{ types::{BasicType, BasicTypeEnum, IntType}, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{ArrayLikeValue, NDArrayValue, ProxyValue}; +use super::NDArrayValue; use crate::codegen::{ stmt::gen_if_callback, types::{ ndarray::{ContiguousNDArrayType, NDArrayType}, - structure::StructField, + structure::{StructField, StructProxyType}, }, + values::{structure::StructProxyValue, ArrayLikeValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -23,6 +24,27 @@ pub struct ContiguousNDArrayValue<'ctx> { } impl<'ctx> ContiguousNDArrayValue<'ctx> { + /// Creates an [`ContiguousNDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, llvm_usize, name) + } + /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -75,7 +97,7 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { type Type = ContiguousNDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - >::Type::from_type( + >::Type::from_pointer_type( self.as_base_value().get_type(), self.item, self.llvm_usize, @@ -91,6 +113,8 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ContiguousNDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -129,7 +153,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); + let data = self.data_field().load(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 49fdfe17..6ed0ed0b 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,12 @@ use crate::{ irrt, types::{ ndarray::{NDArrayType, NDIndexType}, - structure::StructField, + structure::{StructField, StructProxyType}, utils::SliceType, }, - values::{ndarray::NDArrayValue, utils::RustSlice, ProxyValue}, + values::{ + ndarray::NDArrayValue, structure::StructProxyValue, utils::RustSlice, ProxyValue, + }, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -30,6 +32,26 @@ pub struct NDIndexValue<'ctx> { } impl<'ctx> NDIndexValue<'ctx> { + /// Creates an [`NDIndexValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`NDIndexValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -73,7 +95,7 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { type Type = NDIndexType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -85,6 +107,8 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDIndexValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIndexValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 38c87e0c..dcb69473 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -2,14 +2,14 @@ use std::iter::repeat_n; use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor, - TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::{ @@ -18,7 +18,11 @@ use crate::{ llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, stmt::gen_for_callback_incrementing, type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, + types::{ + ndarray::NDArrayType, + structure::{StructField, StructProxyType}, + TupleType, + }, CodeGenContext, CodeGenerator, }, typecheck::typedef::{Type, TypeEnum}, @@ -49,6 +53,28 @@ pub struct NDArrayValue<'ctx> { } impl<'ctx> NDArrayValue<'ctx> { + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, ndims, llvm_usize, name) + } + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -63,52 +89,45 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } - fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).ndims - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().ndims } /// Stores the number of dimensions `ndims` into this instance. pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); + self.ndims_field().store(ctx, self.value, ndims, self.name); } /// Returns the number of dimensions of this `NDArray` as a value. pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + self.ndims_field().load(ctx, self.value, self.name) } - fn itemsize_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).itemsize + fn itemsize_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().itemsize } /// Stores the size of each element `itemsize` into this instance. pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); - self.itemsize_field(ctx).store(ctx, self.value, itemsize, self.name); + self.itemsize_field().store(ctx, self.value, itemsize, self.name); } /// Returns the size of each element of this `NDArray` as a value. pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.itemsize_field(ctx).load(ctx, self.value, self.name) + self.itemsize_field().load(ctx, self.value, self.name) } - fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).shape + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().shape } /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).store(ctx, self.value, dims, self.name); + self.shape_field().store(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -127,16 +146,13 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayShapeProxy(self) } - fn strides_field( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).strides + fn strides_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().strides } /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).store(ctx, self.value, strides, self.name); + self.strides_field().store(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -155,14 +171,14 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayStridesProxy(self) } - fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).data + fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().data } /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name) + self.data_field().ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. @@ -171,7 +187,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).store(ctx, self.value, data.into_pointer_value(), self.name); + self.data_field().store(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -467,7 +483,7 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - NDArrayType::from_type( + NDArrayType::from_pointer_type( self.as_base_value().get_type(), self.dtype, self.ndims, @@ -484,6 +500,8 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -508,7 +526,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.shape_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -606,7 +624,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.strides_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -704,7 +722,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.data_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -958,7 +976,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); ScalarOrNDArray::NDArray(ndarray) } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index e4855743..c1bf7513 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -1,15 +1,18 @@ use inkwell::{ types::{BasicType, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{NDArrayValue, ProxyValue}; +use super::NDArrayValue; use crate::codegen::{ irrt, stmt::{gen_for_callback, BreakContinueHooks}, - types::{ndarray::NDIterType, structure::StructField}, - values::{ArraySliceValue, TypedArrayLikeAdapter}, + types::{ + ndarray::NDIterType, + structure::{StructField, StructProxyType}, + }, + values::{structure::StructProxyValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter}, CodeGenContext, CodeGenerator, }; @@ -23,6 +26,28 @@ pub struct NDIterValue<'ctx> { } impl<'ctx> NDIterValue<'ctx> { + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, parent, indices, llvm_usize, name) + } + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -56,11 +81,8 @@ impl<'ctx> NDIterValue<'ctx> { irrt::ndarray::call_nac3_nditer_next(ctx, *self); } - fn element_field( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).element + fn element_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().element } /// Get pointer to the current element. @@ -68,7 +90,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); + let p = self.element_field().load(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -81,14 +103,14 @@ impl<'ctx> NDIterValue<'ctx> { ctx.builder.build_load(p, "value").unwrap() } - fn nth_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).nth + fn nth_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().nth } /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).load(ctx, self.as_abi_value(ctx), self.name) + self.nth_field().load(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. @@ -110,7 +132,7 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { type Type = NDIterType<'ctx>; fn get_type(&self) -> Self::Type { - NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize) + NDIterType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -122,6 +144,8 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDIterValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIterValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index 3ac2795d..b3331b6f 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -42,7 +42,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) - .map_value(input_seq.into_pointer_value(), None); + .map_pointer_value(input_seq.into_pointer_value(), None); let len = input_seq.load_size(ctx, None); // TODO: Find a way to remove this mid-BB allocation @@ -86,7 +86,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) - .map_value(input_seq.into_struct_value(), None); + .map_struct_value(input_seq.into_struct_value(), None); let len = input_seq.get_type().num_elements(); diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 20bdba79..67e623a1 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -1,10 +1,10 @@ use inkwell::{ types::IntType, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{ArrayValue, BasicValueEnum, IntValue, PointerValue}, }; use super::ProxyValue; -use crate::codegen::{types::RangeType, CodeGenContext}; +use crate::codegen::{types::RangeType, CodeGenContext, CodeGenerator}; /// Proxy type for accessing a `range` value in LLVM. #[derive(Copy, Clone)] @@ -15,6 +15,26 @@ pub struct RangeValue<'ctx> { } impl<'ctx> RangeValue<'ctx> { + /// Creates an [`RangeValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_array_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: ArrayValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`RangeValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -142,7 +162,7 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type(), self.llvm_usize) + RangeType::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 08b2b8be..320e2190 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{BasicValue, BasicValueEnum, StructValue}, + values::{BasicValue, BasicValueEnum, PointerValue, StructValue}, }; use super::ProxyValue; @@ -26,6 +26,24 @@ impl<'ctx> TupleValue<'ctx> { Self { value, llvm_usize, name } } + /// Creates an [`TupleValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + Self::from_struct_value( + ctx.builder + .build_load(ptr, name.unwrap_or_default()) + .map(BasicValueEnum::into_struct_value) + .unwrap(), + llvm_usize, + name, + ) + } + /// Stores a value into the tuple element at the given `index`. pub fn store_element( &mut self, @@ -62,7 +80,7 @@ impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { type Type = TupleType<'ctx>; fn get_type(&self) -> Self::Type { - TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize) + TupleType::from_struct_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index f7739357..549e556d 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -1,14 +1,17 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, }; use nac3parser::ast::Expr; use crate::{ codegen::{ - types::{structure::StructField, utils::SliceType}, - values::ProxyValue, + types::{ + structure::{StructField, StructProxyType}, + utils::SliceType, + }, + values::{structure::StructProxyValue, ProxyValue}, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -24,6 +27,27 @@ pub struct SliceValue<'ctx> { } impl<'ctx> SliceValue<'ctx> { + /// Creates an [`SliceValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, int_ty, llvm_usize, name) + } + /// Creates an [`SliceValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -155,7 +179,7 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { type Type = SliceType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.int_ty, self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.int_ty, self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -167,6 +191,8 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for SliceValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: SliceValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 165f64a8..eff614e5 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -577,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeType::new(ctx).map_value(zelf, Some("range")); + let zelf = RangeType::new(ctx).map_pointer_value(zelf, Some("range")); let mut start = None; let mut stop = None; @@ -1280,7 +1280,7 @@ impl<'a> BuiltinBuilder<'a> { let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let size = ctx .builder @@ -1312,7 +1312,7 @@ impl<'a> BuiltinBuilder<'a> { args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let result_tuple = match prim { PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), @@ -1353,7 +1353,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg_val.into_pointer_value(), None); + .map_pointer_value(arg_val.into_pointer_value(), None); let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument Ok(Some(ndarray.as_abi_value(ctx).into())) @@ -1391,7 +1391,7 @@ impl<'a> BuiltinBuilder<'a> { args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray_val.into_pointer_value(), None); + .map_pointer_value(ndarray_val.into_pointer_value(), None); let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); From d394b24304ed06bd758fdf684f943131dd84e282 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 13:10:13 +0800 Subject: [PATCH 79/80] [meta] flake: Add LLVM bintools to artiq-{instrumented,pgo} --- flake.nix | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index a48ff69b..51551c77 100644 --- a/flake.nix +++ b/flake.nix @@ -85,7 +85,7 @@ name = "nac3artiq-instrumented"; src = self; inherit (nac3artiq) cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-instrumented ]; buildInputs = [ pkgs.python3 llvm-nac3-instrumented ]; cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ]; doCheck = false; @@ -148,7 +148,7 @@ name = "nac3artiq-pgo"; src = self; inherit (nac3artiq) cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-pgo ]; buildInputs = [ pkgs.python3 llvm-nac3-pgo ]; cargoBuildFlags = [ "--package" "nac3artiq" ]; cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ]; From c32c68b0b0823a21a4528262321dc4c1ebc4fde4 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 5 Feb 2025 15:42:23 +0800 Subject: [PATCH 80/80] flake: update dependencies --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 3e4af709..7e36e537 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1736798957, - "narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=", + "lastModified": 1738680400, + "narHash": "sha256-ooLh+XW8jfa+91F1nhf9OF7qhuA/y1ChLx6lXDNeY5U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3", + "rev": "799ba5bffed04ced7067a91798353d360788b30d", "type": "github" }, "original": {