From c28166efb8ef91cb7be301c91f0834be3202fc72 Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 9 Aug 2024 16:43:31 +0800 Subject: [PATCH] WIP: core/typecheck: after np.newaxis and ... --- .../src/codegen/structure/ndarray/indexing.rs | 2 + nac3core/src/toplevel/mod.rs | 1 + nac3core/src/toplevel/numpy.rs | 2 - nac3core/src/toplevel/option.rs | 46 +++++++++++++++++++ nac3core/src/typecheck/function_check.rs | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 19 +++++++- nac3core/src/typecheck/typedef/mod.rs | 3 +- 7 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 nac3core/src/toplevel/option.rs diff --git a/nac3core/src/codegen/structure/ndarray/indexing.rs b/nac3core/src/codegen/structure/ndarray/indexing.rs index d88cf81e..5d767140 100644 --- a/nac3core/src/codegen/structure/ndarray/indexing.rs +++ b/nac3core/src/codegen/structure/ndarray/indexing.rs @@ -335,6 +335,8 @@ pub mod util { RustNDIndex::Slice(RustUserSlice { start, stop, step }) } else { + todo!("implement me"); + // Anything else that is not a slice (might be illegal values), // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 7dfd8373..b9f209be 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -31,6 +31,7 @@ pub mod builtins; pub mod composer; pub mod helper; pub mod numpy; +pub mod option; pub mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 015b4eac..0d14a61c 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::{ symbol_resolver::SymbolValue, toplevel::helper::PrimDef, diff --git a/nac3core/src/toplevel/option.rs b/nac3core/src/toplevel/option.rs new file mode 100644 index 00000000..1dd7f343 --- /dev/null +++ b/nac3core/src/toplevel/option.rs @@ -0,0 +1,46 @@ +use itertools::Itertools; + +use crate::{ + toplevel::helper::PrimDef, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{Type, TypeEnum, Unifier, VarMap}, + }, +}; + +// TODO: This entire module is duplicated code (numpy.rs also has these kinds of things) + +/// Creates a `option` [`Type`] with the given type arguments. +/// +/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not +/// specialized. +/// * `ndims` - The number of dimensions of the `option`, or [`None`] if the type variable is not +/// specialized. +pub fn make_option_ty( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + dtype: Option, +) -> Type { + subst_option_tvars(unifier, primitives.option, dtype) +} + +/// Substitutes type variables in `option`. +/// +/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not +/// specialized. +pub fn subst_option_tvars(unifier: &mut Unifier, option: Type, dtype: Option) -> Type { + let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(option) else { + panic!("Expected `option` to be TObj, but got {}", unifier.stringify(option)) + }; + debug_assert_eq!(*obj_id, PrimDef::Option.id()); + + let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec(); + debug_assert_eq!(tvar_ids.len(), 1); + + let mut tvar_subst = VarMap::new(); + if let Some(dtype) = dtype { + tvar_subst.insert(tvar_ids[0], dtype); + } + + unifier.subst(option, &tvar_subst).unwrap_or(option) +} diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index b3790994..fe633253 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -80,7 +80,7 @@ impl<'a> Inferencer<'a> { return Err(HashSet::from([format!( "expected concrete type at {} but got {}", expr.location, - self.unifier.get_ty(*ty).get_type_name() + self.unifier.stringify(*ty) )])); } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 3f7713d4..23408166 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -11,6 +11,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::option::make_option_ty; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -2264,16 +2265,30 @@ impl<'a> Inferencer<'a> { } } ExprKind::Constant { value: Constant::Ellipsis, .. } => { - // Handle `...`. Do nothing. + // Handle `...`. + + // See https://git.m-labs.hk/M-Labs/nac3/issues/486 + // Force `...` to have `()` (completely bogus) to make it concrete. + let empty_tuple = TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }; + let empty_tuple = self.unifier.add_ty(empty_tuple); + self.unify(index.custom.unwrap(), empty_tuple, &index.location)?; } ExprKind::Name { id, .. } if id == &"none".into() => { // Handle `np.newaxis` / `None`. dims_to_subtract -= 1; + + // "none" itself has type `Option[T]`, and since we have a stray `T` (non-concrete type). + // We will force the type to be `Option[()]` to make it concrete. (TODO: is there a void type?) + let empty_tuple = TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }; + let empty_tuple = self.unifier.add_ty(empty_tuple); + let expected_type = + make_option_ty(self.unifier, self.primitives, Some(empty_tuple)); + self.unify(index.custom.unwrap(), expected_type, &index.location)?; } _ => { // Treat anything else as an integer index, and force unify their type to int32. - self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?; dims_to_subtract += 1; + self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?; } } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e125ff80..e93656e3 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -621,8 +621,7 @@ impl Unifier { } pub fn unify_call( - &mut self, - call: &Call, + &mut self, call: &Call, b: Type, signature: &FunSignature, ) -> Result<(), TypeError> {