forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: minor cleanup

This commit is contained in:
lyken 2024-08-14 10:19:09 +08:00
parent 1d7184708f
commit bb1687f8a4
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
9 changed files with 41 additions and 90 deletions

View File

@ -6,10 +6,7 @@ use nac3parser::ast::StrRef;
use crate::{ use crate::{
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject}, codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
numpy::{extract_ndims, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
}; };

View File

@ -1,12 +1,9 @@
use itertools::Itertools; use itertools::Itertools;
use crate::{ use crate::codegen::{
codegen::{ irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, model::*,
model::*, CodeGenContext, CodeGenerator,
CodeGenContext, CodeGenerator,
},
toplevel::numpy::get_broadcast_all_ndims,
}; };
use super::NDArrayObject; use super::NDArrayObject;
@ -119,8 +116,9 @@ impl<'ctx> NDArrayObject<'ctx> {
let sizet_model = IntModel(SizeT); let sizet_model = IntModel(SizeT);
let broadcast_ndims_int = // Infer the broadcast output ndims.
get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims)); let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int); let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
let broadcast_shape = let broadcast_shape =
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape"); sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");

View File

@ -252,6 +252,7 @@ impl<'ctx> ScalarObject<'ctx> {
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 = let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder ctx.builder
.build_select(val_gez, to_uint64, to_int64, "conv") .build_select(val_gez, to_uint64, to_int64, "conv")
.unwrap() .unwrap()
@ -328,8 +329,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Invoke NAC3's builtin `np_round()`. /// Invoke NAC3's builtin `np_round()`.
/// ///
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result /// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
/// on "tie" cases and return type.
#[must_use] #[must_use]
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {

View File

@ -21,7 +21,7 @@ use crate::{
structure::NDArray, structure::NDArray,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
toplevel::numpy::{extract_ndims, unpack_ndarray_var_tys}, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type, typecheck::typedef::Type,
}; };
use indexing::RustNDIndex; use indexing::RustNDIndex;

View File

@ -1,6 +1,6 @@
use std::iter::once; use std::iter::once;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap; use indexmap::IndexMap;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -25,10 +25,7 @@ use crate::{
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{ toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
helper::PrimDef,
numpy::{create_ndims, make_ndarray_ty},
},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
}; };

View File

@ -1018,3 +1018,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
_ => 0, _ => 0,
} }
} }
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}

View File

@ -84,60 +84,3 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
} }
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}
/// Return the ndims after broadcasting ndarrays of different ndims.
///
/// Panics if the input list is empty.
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
where
I: IntoIterator<Item = u64>,
{
ndims.into_iter().max().unwrap()
}
pub fn split_scalar_or_ndarray_type(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ty: Type,
) -> Either<Type, (Type, Type)> {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == primitives.ndarray.obj_id(unifier).unwrap() => {
Either::Right(unpack_ndarray_var_tys(unifier, ty))
}
_ => Either::Left(ty),
}
}
pub fn split_as_ndarray_type(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ty: Type,
) -> (Type, Type) {
match split_scalar_or_ndarray_type(unifier, primitives, ty) {
Either::Left(dtype) => {
let ndims = unifier.get_fresh_literal(vec![SymbolValue::U64(0)], None);
(dtype, ndims)
}
Either::Right((dtype, ndims)) => (dtype, ndims),
}
}

View File

@ -1,9 +1,6 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef};
use crate::toplevel::numpy::{ use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
extract_ndims, make_ndarray_ty, split_as_ndarray_type, split_scalar_or_ndarray_type,
unpack_ndarray_var_tys,
};
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@ -523,11 +520,11 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
let (lhs_dtype, lhs_ndims) = split_as_ndarray_type(unifier, primitives, lhs); let lhs_dtype = arraylike_flatten_element_type(unifier, lhs);
let (rhs_dtype, rhs_ndims) = split_as_ndarray_type(unifier, primitives, rhs); let rhs_dtype = arraylike_flatten_element_type(unifier, rhs);
let lhs_ndims = extract_ndims(unifier, lhs_ndims); let lhs_ndims = arraylike_get_ndims(unifier, lhs);
let rhs_ndims = extract_ndims(unifier, rhs_ndims); let rhs_ndims = arraylike_get_ndims(unifier, rhs);
if !(unifier.unioned(lhs_dtype, primitives.float) if !(unifier.unioned(lhs_dtype, primitives.float)
&& unifier.unioned(rhs_dtype, primitives.float)) && unifier.unioned(rhs_dtype, primitives.float))

View File

@ -17,7 +17,7 @@ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext, TopLevelDef, TopLevelContext, TopLevelDef,
}, },
typecheck::typedef::Mapping, typecheck::typedef::Mapping,
@ -1554,8 +1554,7 @@ impl<'a> Inferencer<'a> {
let ndarray = self.fold_expr(args.remove(0))?; let ndarray = self.fold_expr(args.remove(0))?;
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray.custom.unwrap()); let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
let ndims = extract_ndims(self.unifier, ndims);
// Create a tuple of size `ndims` full of int32 // Create a tuple of size `ndims` full of int32
// TODO: Make it usize // TODO: Make it usize