forked from M-Labs/nac3
WIP: core/ndstrides: minor cleanup
This commit is contained in:
parent
1d7184708f
commit
bb1687f8a4
|
@ -6,10 +6,7 @@ use nac3parser::ast::StrRef;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
|
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||||
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
|
||||||
DefinitionId,
|
|
||||||
},
|
|
||||||
typecheck::typedef::{FunSignature, Type},
|
typecheck::typedef::{FunSignature, Type},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,9 @@
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::{
|
use crate::codegen::{
|
||||||
codegen::{
|
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
||||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
model::*,
|
||||||
model::*,
|
CodeGenContext, CodeGenerator,
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
toplevel::numpy::get_broadcast_all_ndims,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::NDArrayObject;
|
use super::NDArrayObject;
|
||||||
|
@ -119,8 +116,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
let broadcast_ndims_int =
|
// Infer the broadcast output ndims.
|
||||||
get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims));
|
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
|
||||||
|
|
||||||
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
|
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
|
||||||
let broadcast_shape =
|
let broadcast_shape =
|
||||||
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");
|
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");
|
||||||
|
|
|
@ -252,6 +252,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||||
let to_uint64 =
|
let to_uint64 =
|
||||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
.build_select(val_gez, to_uint64, to_int64, "conv")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -328,8 +329,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `np_round()`.
|
/// Invoke NAC3's builtin `np_round()`.
|
||||||
///
|
///
|
||||||
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
|
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
||||||
/// on "tie" cases and return type.
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||||
|
|
|
@ -21,7 +21,7 @@ use crate::{
|
||||||
structure::NDArray,
|
structure::NDArray,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
toplevel::numpy::{extract_ndims, unpack_ndarray_var_tys},
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
use indexing::RustNDIndex;
|
use indexing::RustNDIndex;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
@ -25,10 +25,7 @@ use crate::{
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::{
|
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
||||||
helper::PrimDef,
|
|
||||||
numpy::{create_ndims, make_ndarray_ty},
|
|
||||||
},
|
|
||||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1018,3 +1018,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||||
_ => 0,
|
_ => 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||||
|
/// The `ndims` must only contain 1 value.
|
||||||
|
#[must_use]
|
||||||
|
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||||
|
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||||
|
panic!("ndims_ty should be a TLiteral");
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||||
|
|
||||||
|
let ndims = values[0].clone();
|
||||||
|
u64::try_from(ndims).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||||
|
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||||
|
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||||
|
}
|
||||||
|
|
|
@ -84,60 +84,3 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI
|
||||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
||||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
|
||||||
/// The `ndims` must only contain 1 value.
|
|
||||||
#[must_use]
|
|
||||||
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
|
||||||
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
|
||||||
panic!("ndims_ty should be a TLiteral");
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
|
||||||
|
|
||||||
let ndims = values[0].clone();
|
|
||||||
u64::try_from(ndims).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
|
||||||
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
|
||||||
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return the ndims after broadcasting ndarrays of different ndims.
|
|
||||||
///
|
|
||||||
/// Panics if the input list is empty.
|
|
||||||
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
|
|
||||||
where
|
|
||||||
I: IntoIterator<Item = u64>,
|
|
||||||
{
|
|
||||||
ndims.into_iter().max().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn split_scalar_or_ndarray_type(
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
ty: Type,
|
|
||||||
) -> Either<Type, (Type, Type)> {
|
|
||||||
match &*unifier.get_ty(ty) {
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == primitives.ndarray.obj_id(unifier).unwrap() => {
|
|
||||||
Either::Right(unpack_ndarray_var_tys(unifier, ty))
|
|
||||||
}
|
|
||||||
_ => Either::Left(ty),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn split_as_ndarray_type(
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
ty: Type,
|
|
||||||
) -> (Type, Type) {
|
|
||||||
match split_scalar_or_ndarray_type(unifier, primitives, ty) {
|
|
||||||
Either::Left(dtype) => {
|
|
||||||
let ndims = unifier.get_fresh_literal(vec![SymbolValue::U64(0)], None);
|
|
||||||
(dtype, ndims)
|
|
||||||
}
|
|
||||||
Either::Right((dtype, ndims)) => (dtype, ndims),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef};
|
||||||
use crate::toplevel::numpy::{
|
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||||
extract_ndims, make_ndarray_ty, split_as_ndarray_type, split_scalar_or_ndarray_type,
|
|
||||||
unpack_ndarray_var_tys,
|
|
||||||
};
|
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
|
@ -523,11 +520,11 @@ pub fn typeof_binop(
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
let (lhs_dtype, lhs_ndims) = split_as_ndarray_type(unifier, primitives, lhs);
|
let lhs_dtype = arraylike_flatten_element_type(unifier, lhs);
|
||||||
let (rhs_dtype, rhs_ndims) = split_as_ndarray_type(unifier, primitives, rhs);
|
let rhs_dtype = arraylike_flatten_element_type(unifier, rhs);
|
||||||
|
|
||||||
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
let lhs_ndims = arraylike_get_ndims(unifier, lhs);
|
||||||
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
let rhs_ndims = arraylike_get_ndims(unifier, rhs);
|
||||||
|
|
||||||
if !(unifier.unioned(lhs_dtype, primitives.float)
|
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||||
&& unifier.unioned(rhs_dtype, primitives.float))
|
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||||
|
|
|
@ -17,7 +17,7 @@ use crate::{
|
||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||||
numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
TopLevelContext, TopLevelDef,
|
TopLevelContext, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::typedef::Mapping,
|
typecheck::typedef::Mapping,
|
||||||
|
@ -1554,8 +1554,7 @@ impl<'a> Inferencer<'a> {
|
||||||
|
|
||||||
let ndarray = self.fold_expr(args.remove(0))?;
|
let ndarray = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray.custom.unwrap());
|
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
|
||||||
let ndims = extract_ndims(self.unifier, ndims);
|
|
||||||
|
|
||||||
// Create a tuple of size `ndims` full of int32
|
// Create a tuple of size `ndims` full of int32
|
||||||
// TODO: Make it usize
|
// TODO: Make it usize
|
||||||
|
|
Loading…
Reference in New Issue