forked from M-Labs/nac3
WIP: core/ndstrides: minor cleanup
This commit is contained in:
parent
1d7184708f
commit
bb1687f8a4
@ -6,10 +6,7 @@ use nac3parser::ast::StrRef;
|
||||
use crate::{
|
||||
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
||||
DefinitionId,
|
||||
},
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
|
@ -1,12 +1,9 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::numpy::get_broadcast_all_ndims,
|
||||
use crate::codegen::{
|
||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::NDArrayObject;
|
||||
@ -119,8 +116,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let broadcast_ndims_int =
|
||||
get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims));
|
||||
// Infer the broadcast output ndims.
|
||||
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
|
||||
|
||||
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
|
||||
let broadcast_shape =
|
||||
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");
|
||||
|
@ -252,6 +252,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
||||
.unwrap()
|
||||
@ -328,8 +329,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
|
||||
/// Invoke NAC3's builtin `np_round()`.
|
||||
///
|
||||
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
|
||||
/// on "tie" cases and return type.
|
||||
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
||||
#[must_use]
|
||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
|
@ -21,7 +21,7 @@ use crate::{
|
||||
structure::NDArray,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::numpy::{extract_ndims, unpack_ndarray_var_tys},
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
use indexing::RustNDIndex;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use std::iter::once;
|
||||
|
||||
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||
use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||
use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
@ -25,10 +25,7 @@ use crate::{
|
||||
stmt::exn_constructor,
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{create_ndims, make_ndarray_ty},
|
||||
},
|
||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||
};
|
||||
|
||||
|
@ -1018,3 +1018,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||
/// The `ndims` must only contain 1 value.
|
||||
#[must_use]
|
||||
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||
panic!("ndims_ty should be a TLiteral");
|
||||
};
|
||||
|
||||
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||
|
||||
let ndims = values[0].clone();
|
||||
u64::try_from(ndims).unwrap()
|
||||
}
|
||||
|
||||
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||
}
|
||||
|
@ -84,60 +84,3 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI
|
||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
||||
}
|
||||
|
||||
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||
/// The `ndims` must only contain 1 value.
|
||||
#[must_use]
|
||||
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||
panic!("ndims_ty should be a TLiteral");
|
||||
};
|
||||
|
||||
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||
|
||||
let ndims = values[0].clone();
|
||||
u64::try_from(ndims).unwrap()
|
||||
}
|
||||
|
||||
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||
}
|
||||
|
||||
/// Return the ndims after broadcasting ndarrays of different ndims.
|
||||
///
|
||||
/// Panics if the input list is empty.
|
||||
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
|
||||
where
|
||||
I: IntoIterator<Item = u64>,
|
||||
{
|
||||
ndims.into_iter().max().unwrap()
|
||||
}
|
||||
|
||||
pub fn split_scalar_or_ndarray_type(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
ty: Type,
|
||||
) -> Either<Type, (Type, Type)> {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == primitives.ndarray.obj_id(unifier).unwrap() => {
|
||||
Either::Right(unpack_ndarray_var_tys(unifier, ty))
|
||||
}
|
||||
_ => Either::Left(ty),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn split_as_ndarray_type(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
ty: Type,
|
||||
) -> (Type, Type) {
|
||||
match split_scalar_or_ndarray_type(unifier, primitives, ty) {
|
||||
Either::Left(dtype) => {
|
||||
let ndims = unifier.get_fresh_literal(vec![SymbolValue::U64(0)], None);
|
||||
(dtype, ndims)
|
||||
}
|
||||
Either::Right((dtype, ndims)) => (dtype, ndims),
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,6 @@
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::{
|
||||
extract_ndims, make_ndarray_ty, split_as_ndarray_type, split_scalar_or_ndarray_type,
|
||||
unpack_ndarray_var_tys,
|
||||
};
|
||||
use crate::toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef};
|
||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
@ -523,11 +520,11 @@ pub fn typeof_binop(
|
||||
}
|
||||
|
||||
Operator::MatMult => {
|
||||
let (lhs_dtype, lhs_ndims) = split_as_ndarray_type(unifier, primitives, lhs);
|
||||
let (rhs_dtype, rhs_ndims) = split_as_ndarray_type(unifier, primitives, rhs);
|
||||
let lhs_dtype = arraylike_flatten_element_type(unifier, lhs);
|
||||
let rhs_dtype = arraylike_flatten_element_type(unifier, rhs);
|
||||
|
||||
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
||||
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
||||
let lhs_ndims = arraylike_get_ndims(unifier, lhs);
|
||||
let rhs_ndims = arraylike_get_ndims(unifier, rhs);
|
||||
|
||||
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||
|
@ -17,7 +17,7 @@ use crate::{
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
toplevel::{
|
||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||
numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelContext, TopLevelDef,
|
||||
},
|
||||
typecheck::typedef::Mapping,
|
||||
@ -1554,8 +1554,7 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let ndarray = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray.custom.unwrap());
|
||||
let ndims = extract_ndims(self.unifier, ndims);
|
||||
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
|
||||
|
||||
// Create a tuple of size `ndims` full of int32
|
||||
// TODO: Make it usize
|
||||
|
Loading…
Reference in New Issue
Block a user