From bb1687f8a46ad22673cb2fa034bcfb619bceb96a Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 14 Aug 2024 10:19:09 +0800 Subject: [PATCH] WIP: core/ndstrides: minor cleanup --- nac3core/src/codegen/numpy_new.rs | 5 +- .../src/codegen/object/ndarray/broadcast.rs | 16 +++--- .../src/codegen/object/ndarray/functions.rs | 4 +- nac3core/src/codegen/object/ndarray/mod.rs | 2 +- nac3core/src/toplevel/builtins.rs | 7 +-- nac3core/src/toplevel/helper.rs | 20 +++++++ nac3core/src/toplevel/numpy.rs | 57 ------------------- nac3core/src/typecheck/magic_methods.rs | 15 ++--- nac3core/src/typecheck/type_inferencer/mod.rs | 5 +- 9 files changed, 41 insertions(+), 90 deletions(-) diff --git a/nac3core/src/codegen/numpy_new.rs b/nac3core/src/codegen/numpy_new.rs index 8a87af8b..5d259d98 100644 --- a/nac3core/src/codegen/numpy_new.rs +++ b/nac3core/src/codegen/numpy_new.rs @@ -6,10 +6,7 @@ use nac3parser::ast::StrRef; use crate::{ codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject}, symbol_resolver::ValueEnum, - toplevel::{ - numpy::{extract_ndims, unpack_ndarray_var_tys}, - DefinitionId, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, typecheck::typedef::{FunSignature, Type}, }; diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs index d75041e4..68098d7b 100644 --- a/nac3core/src/codegen/object/ndarray/broadcast.rs +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -1,12 +1,9 @@ use itertools::Itertools; -use crate::{ - codegen::{ - irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, - model::*, - CodeGenContext, CodeGenerator, - }, - toplevel::numpy::get_broadcast_all_ndims, +use crate::codegen::{ + irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, + model::*, + CodeGenContext, CodeGenerator, }; use super::NDArrayObject; @@ -119,8 +116,9 @@ impl<'ctx> NDArrayObject<'ctx> { let sizet_model = IntModel(SizeT); - let broadcast_ndims_int = - get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims)); + // Infer the broadcast output ndims. + let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); + let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int); let broadcast_shape = sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape"); diff --git a/nac3core/src/codegen/object/ndarray/functions.rs b/nac3core/src/codegen/object/ndarray/functions.rs index 285b8fbd..69d35b21 100644 --- a/nac3core/src/codegen/object/ndarray/functions.rs +++ b/nac3core/src/codegen/object/ndarray/functions.rs @@ -252,6 +252,7 @@ impl<'ctx> ScalarObject<'ctx> { ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder .build_select(val_gez, to_uint64, to_int64, "conv") .unwrap() @@ -328,8 +329,7 @@ impl<'ctx> ScalarObject<'ctx> { /// Invoke NAC3's builtin `np_round()`. /// - /// NOTE: `np.round()` has different behaviors than `round()` in terms of their result - /// on "tie" cases and return type. + /// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type. #[must_use] pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 5c4e94aa..fc0bfbcd 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -21,7 +21,7 @@ use crate::{ structure::NDArray, CodeGenContext, CodeGenerator, }, - toplevel::numpy::{extract_ndims, unpack_ndarray_var_tys}, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; use indexing::RustNDIndex; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ed245835..f1d1fdf3 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,6 @@ use std::iter::once; -use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; +use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -25,10 +25,7 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::{ - helper::PrimDef, - numpy::{create_ndims, make_ndarray_ty}, - }, + toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, }; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index ece60120..5df95a6c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1018,3 +1018,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { _ => 0, } } + +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +} diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index b4851d6f..2b4ea43b 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -84,60 +84,3 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() } - -/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. -/// The `ndims` must only contain 1 value. -#[must_use] -pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { - let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); - let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { - panic!("ndims_ty should be a TLiteral"); - }; - - assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); - - let ndims = values[0].clone(); - u64::try_from(ndims).unwrap() -} - -/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. -pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { - unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) -} - -/// Return the ndims after broadcasting ndarrays of different ndims. -/// -/// Panics if the input list is empty. -pub fn get_broadcast_all_ndims(ndims: I) -> u64 -where - I: IntoIterator, -{ - ndims.into_iter().max().unwrap() -} - -pub fn split_scalar_or_ndarray_type( - unifier: &mut Unifier, - primitives: &PrimitiveStore, - ty: Type, -) -> Either { - match &*unifier.get_ty(ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == primitives.ndarray.obj_id(unifier).unwrap() => { - Either::Right(unpack_ndarray_var_tys(unifier, ty)) - } - _ => Either::Left(ty), - } -} - -pub fn split_as_ndarray_type( - unifier: &mut Unifier, - primitives: &PrimitiveStore, - ty: Type, -) -> (Type, Type) { - match split_scalar_or_ndarray_type(unifier, primitives, ty) { - Either::Left(dtype) => { - let ndims = unifier.get_fresh_literal(vec![SymbolValue::U64(0)], None); - (dtype, ndims) - } - Either::Right((dtype, ndims)) => (dtype, ndims), - } -} diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index eb740edb..d4cb31d6 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,9 +1,6 @@ use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::{ - extract_ndims, make_ndarray_ty, split_as_ndarray_type, split_scalar_or_ndarray_type, - unpack_ndarray_var_tys, -}; +use crate::toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}; +use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, @@ -523,11 +520,11 @@ pub fn typeof_binop( } Operator::MatMult => { - let (lhs_dtype, lhs_ndims) = split_as_ndarray_type(unifier, primitives, lhs); - let (rhs_dtype, rhs_ndims) = split_as_ndarray_type(unifier, primitives, rhs); + let lhs_dtype = arraylike_flatten_element_type(unifier, lhs); + let rhs_dtype = arraylike_flatten_element_type(unifier, rhs); - let lhs_ndims = extract_ndims(unifier, lhs_ndims); - let rhs_ndims = extract_ndims(unifier, rhs_ndims); + let lhs_ndims = arraylike_get_ndims(unifier, lhs); + let rhs_ndims = arraylike_get_ndims(unifier, rhs); if !(unifier.unioned(lhs_dtype, primitives.float) && unifier.unioned(rhs_dtype, primitives.float)) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7187740e..bc210c27 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -17,7 +17,7 @@ use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, - numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, TopLevelDef, }, typecheck::typedef::Mapping, @@ -1554,8 +1554,7 @@ impl<'a> Inferencer<'a> { let ndarray = self.fold_expr(args.remove(0))?; - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray.custom.unwrap()); - let ndims = extract_ndims(self.unifier, ndims); + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); // Create a tuple of size `ndims` full of int32 // TODO: Make it usize