WIP: core/typecheck: after np.newaxis and ...

This commit is contained in:
lyken 2024-08-09 16:43:31 +08:00
parent 2546053013
commit c28166efb8
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
7 changed files with 68 additions and 7 deletions

View File

@ -335,6 +335,8 @@ pub mod util {
RustNDIndex::Slice(RustUserSlice { start, stop, step })
} else {
todo!("implement me");
// Anything else that is not a slice (might be illegal values),
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(

View File

@ -31,6 +31,7 @@ pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod option;
pub mod type_annotation;
use composer::*;
use type_annotation::*;

View File

@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::{
symbol_resolver::SymbolValue,
toplevel::helper::PrimDef,

View File

@ -0,0 +1,46 @@
use itertools::Itertools;
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
// TODO: This entire module is duplicated code (numpy.rs also has these kinds of things)
/// Creates a `option` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `option`, or [`None`] if the type variable is not
/// specialized.
pub fn make_option_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
) -> Type {
subst_option_tvars(unifier, primitives.option, dtype)
}
/// Substitutes type variables in `option`.
///
/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_option_tvars(unifier: &mut Unifier, option: Type, dtype: Option<Type>) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(option) else {
panic!("Expected `option` to be TObj, but got {}", unifier.stringify(option))
};
debug_assert_eq!(*obj_id, PrimDef::Option.id());
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 1);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
unifier.subst(option, &tvar_subst).unwrap_or(option)
}

View File

@ -80,7 +80,7 @@ impl<'a> Inferencer<'a> {
return Err(HashSet::from([format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
self.unifier.stringify(*ty)
)]));
}
}

View File

@ -11,6 +11,7 @@ use super::{
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
},
};
use crate::toplevel::option::make_option_ty;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
@ -2264,16 +2265,30 @@ impl<'a> Inferencer<'a> {
}
}
ExprKind::Constant { value: Constant::Ellipsis, .. } => {
// Handle `...`. Do nothing.
// Handle `...`.
// See https://git.m-labs.hk/M-Labs/nac3/issues/486
// Force `...` to have `()` (completely bogus) to make it concrete.
let empty_tuple = TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false };
let empty_tuple = self.unifier.add_ty(empty_tuple);
self.unify(index.custom.unwrap(), empty_tuple, &index.location)?;
}
ExprKind::Name { id, .. } if id == &"none".into() => {
// Handle `np.newaxis` / `None`.
dims_to_subtract -= 1;
// "none" itself has type `Option[T]`, and since we have a stray `T` (non-concrete type).
// We will force the type to be `Option[()]` to make it concrete. (TODO: is there a void type?)
let empty_tuple = TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false };
let empty_tuple = self.unifier.add_ty(empty_tuple);
let expected_type =
make_option_ty(self.unifier, self.primitives, Some(empty_tuple));
self.unify(index.custom.unwrap(), expected_type, &index.location)?;
}
_ => {
// Treat anything else as an integer index, and force unify their type to int32.
self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?;
dims_to_subtract += 1;
self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?;
}
}
}

View File

@ -621,8 +621,7 @@ impl Unifier {
}
pub fn unify_call(
&mut self,
call: &Call,
&mut self, call: &Call,
b: Type,
signature: &FunSignature,
) -> Result<(), TypeError> {