From 4948395ca2e045aebd0d2ab84003d83d66ada3e9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 18 Jul 2024 15:47:40 +0800 Subject: [PATCH] core/toplevel/type_annotation: Add handling for mismatching class def Primitive types only contain fields in its Type and not its TopLevelDef. This causes primitive object types to lack some fields. --- nac3artiq/demo/list_cmp.py | 24 +++ nac3core/src/toplevel/builtins.rs | 2 +- nac3core/src/toplevel/composer.rs | 10 +- nac3core/src/toplevel/helper.rs | 36 ++-- nac3core/src/toplevel/type_annotation.rs | 226 ++++++++++++++--------- nac3standalone/src/main.rs | 6 +- 6 files changed, 195 insertions(+), 109 deletions(-) create mode 100644 nac3artiq/demo/list_cmp.py diff --git a/nac3artiq/demo/list_cmp.py b/nac3artiq/demo/list_cmp.py new file mode 100644 index 000000000..7454532ff --- /dev/null +++ b/nac3artiq/demo/list_cmp.py @@ -0,0 +1,24 @@ +from min_artiq import * +from numpy import int32 + + +@nac3 +class EmptyList: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @rpc + def get_empty(self) -> list[int32]: + return [] + + @kernel + def run(self): + a: list[int32] = self.get_empty() + if a != []: + raise ValueError + + +if __name__ == "__main__": + EmptyList().run() diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index bcb3c433d..83256b149 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -564,7 +564,7 @@ impl<'a> BuiltinBuilder<'a> { match (&tld, prim.details()) { ( TopLevelDef::Class { name, object_id, .. }, - PrimDefDetails::PrimClass { name: exp_name }, + PrimDefDetails::PrimClass { name: exp_name, .. }, ) => { let exp_object_id = prim.id(); assert_eq!(name, &exp_name.into()); diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 5ba07df5b..58ae94fd6 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -766,6 +766,7 @@ impl TopLevelComposer { let target_ty = get_type_from_type_annotation_kinds( &temp_def_list, unifier, + primitives, &def, &mut subst_list, )?; @@ -936,6 +937,7 @@ impl TopLevelComposer { let ty = get_type_from_type_annotation_kinds( temp_def_list.as_ref(), unifier, + primitives_store, &type_annotation, &mut None, )?; @@ -1002,6 +1004,7 @@ impl TopLevelComposer { get_type_from_type_annotation_kinds( &temp_def_list, unifier, + primitives_store, &return_ty_annotation, &mut None, )? @@ -1622,6 +1625,7 @@ impl TopLevelComposer { let self_type = get_type_from_type_annotation_kinds( &def_list, unifier, + primitives_ty, &make_self_type_annotation(type_vars, *object_id), &mut None, )?; @@ -1803,7 +1807,11 @@ impl TopLevelComposer { let ty_ann = make_self_type_annotation(type_vars, *class_id); let self_ty = get_type_from_type_annotation_kinds( - &def_list, unifier, &ty_ann, &mut None, + &def_list, + unifier, + primitives_ty, + &ty_ann, + &mut None, )?; vars.extend(type_vars.iter().map(|ty| { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 5560f41c2..538e653e7 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -113,7 +113,7 @@ pub enum PrimDef { /// Associated details of a [`PrimDef`] pub enum PrimDefDetails { PrimFunction { name: &'static str, simple_name: &'static str }, - PrimClass { name: &'static str }, + PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type }, } impl PrimDef { @@ -155,15 +155,17 @@ impl PrimDef { #[must_use] pub fn name(&self) -> &'static str { match self.details() { - PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, + PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => { + name + } } } /// Get the associated details of this [`PrimDef`] #[must_use] pub fn details(self) -> PrimDefDetails { - fn class(name: &'static str) -> PrimDefDetails { - PrimDefDetails::PrimClass { name } + fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails { + PrimDefDetails::PrimClass { name, get_ty_fn } } fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { @@ -171,22 +173,22 @@ impl PrimDef { } match self { - PrimDef::Int32 => class("int32"), - PrimDef::Int64 => class("int64"), - PrimDef::Float => class("float"), - PrimDef::Bool => class("bool"), - PrimDef::None => class("none"), - PrimDef::Range => class("range"), - PrimDef::Str => class("str"), - PrimDef::Exception => class("Exception"), - PrimDef::UInt32 => class("uint32"), - PrimDef::UInt64 => class("uint64"), - PrimDef::Option => class("Option"), + PrimDef::Int32 => class("int32", |primitives| primitives.int32), + PrimDef::Int64 => class("int64", |primitives| primitives.int64), + PrimDef::Float => class("float", |primitives| primitives.float), + PrimDef::Bool => class("bool", |primitives| primitives.bool), + PrimDef::None => class("none", |primitives| primitives.none), + PrimDef::Range => class("range", |primitives| primitives.range), + PrimDef::Str => class("str", |primitives| primitives.str), + PrimDef::Exception => class("Exception", |primitives| primitives.exception), + PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32), + PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64), + PrimDef::Option => class("Option", |primitives| primitives.option), PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), - PrimDef::List => class("list"), - PrimDef::NDArray => class("ndarray"), + PrimDef::List => class("list", |primitives| primitives.list), + PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray), PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::FunInt32 => fun("int32", None), diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 95d5acad3..3f9b61a0b 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,8 +1,9 @@ use super::*; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimDef; +use crate::toplevel::helper::{PrimDef, PrimDefDetails}; use crate::typecheck::typedef::VarMap; use nac3parser::ast::Constant; +use strum::IntoEnumIterator; #[derive(Clone, Debug)] pub enum TypeAnnotation { @@ -357,6 +358,7 @@ pub fn parse_ast_to_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds( top_level_defs: &[Arc>], unifier: &mut Unifier, + primitives: &PrimitiveStore, ann: &TypeAnnotation, subst_list: &mut Option>, ) -> Result> { @@ -379,100 +381,141 @@ pub fn get_type_from_type_annotation_kinds( let param_ty = params .iter() .map(|x| { - get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) + get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + x, + subst_list, + ) }) .collect::, _>>()?; - let subst = { - // check for compatible range - // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check - let mut result = VarMap::new(); - for (tvar, p) in type_vars.iter().zip(param_ty) { - match unifier.get_ty(*tvar).as_ref() { - TypeEnum::TVar { - id, - range, - fields: None, - name, - loc, - is_const_generic: false, - } => { - let ok: bool = { - // create a temp type var and unify to check compatibility - p == *tvar || { - let temp = unifier.get_fresh_var_with_range( - range.as_slice(), - *name, - *loc, - ); - unifier.unify(temp.ty, p).is_ok() - } - }; - if ok { - result.insert(*id, p); - } else { - return Err(HashSet::from([format!( - "cannot apply type {} to type variable with id {:?}", - unifier.internal_stringify( - p, - &mut |id| format!("class{id}"), - &mut |id| format!("typevar{id}"), - &mut None - ), - *id - )])); - } - } + let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) { + // Primitive TopLevelDefs do not contain all fields that are present in their Type + // counterparts, so directly perform subst on the Type instead. - TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => { - let ty = range[0]; - let ok: bool = { - // create a temp type var and unify to check compatibility - p == *tvar || { - let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc); - unifier.unify(temp.ty, p).is_ok() - } - }; - if ok { - result.insert(*id, p); - } else { - return Err(HashSet::from([format!( - "cannot apply type {} to type variable {}", - unifier.stringify(p), - name.unwrap_or_else(|| format!("typevar{id}").into()), - )])); - } - } + let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else { + unreachable!() + }; - _ => unreachable!("must be generic type var"), + let base_ty = get_ty_fn(primitives); + let params = + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) { + params.clone() + } else { + unreachable!() + }; + + unifier + .subst( + get_ty_fn(primitives), + ¶ms + .iter() + .zip(param_ty) + .map(|(obj_tv, param)| (*obj_tv.0, param)) + .collect(), + ) + .unwrap_or(base_ty) + } else { + let subst = { + // check for compatible range + // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check + let mut result = VarMap::new(); + for (tvar, p) in type_vars.iter().zip(param_ty) { + match unifier.get_ty(*tvar).as_ref() { + TypeEnum::TVar { + id, + range, + fields: None, + name, + loc, + is_const_generic: false, + } => { + let ok: bool = { + // create a temp type var and unify to check compatibility + p == *tvar || { + let temp = unifier.get_fresh_var_with_range( + range.as_slice(), + *name, + *loc, + ); + unifier.unify(temp.ty, p).is_ok() + } + }; + if ok { + result.insert(*id, p); + } else { + return Err(HashSet::from([format!( + "cannot apply type {} to type variable with id {:?}", + unifier.internal_stringify( + p, + &mut |id| format!("class{id}"), + &mut |id| format!("typevar{id}"), + &mut None + ), + *id + )])); + } + } + + TypeEnum::TVar { + id, range, name, loc, is_const_generic: true, .. + } => { + let ty = range[0]; + let ok: bool = { + // create a temp type var and unify to check compatibility + p == *tvar || { + let temp = + unifier.get_fresh_const_generic_var(ty, *name, *loc); + unifier.unify(temp.ty, p).is_ok() + } + }; + if ok { + result.insert(*id, p); + } else { + return Err(HashSet::from([format!( + "cannot apply type {} to type variable {}", + unifier.stringify(p), + name.unwrap_or_else(|| format!("typevar{id}").into()), + )])); + } + } + + _ => unreachable!("must be generic type var"), + } + } + result + }; + // Class Attributes keep a copy with Class Definition and are not added to objects + let mut tobj_fields = methods + .iter() + .map(|(name, ty, _)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + // methods are immutable + (*name, (subst_ty, false)) + }) + .collect::>(); + tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*name, (subst_ty, *mutability)) + })); + let need_subst = !subst.is_empty(); + let ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields, + params: subst, + }); + + if need_subst { + if let Some(wl) = subst_list.as_mut() { + wl.push(ty); } } - result + + ty }; - // Class Attributes keep a copy with Class Definition and are not added to objects - let mut tobj_fields = methods - .iter() - .map(|(name, ty, _)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - // methods are immutable - (*name, (subst_ty, false)) - }) - .collect::>(); - tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*name, (subst_ty, *mutability)) - })); - let need_subst = !subst.is_empty(); - let ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - fields: tobj_fields, - params: subst, - }); - if need_subst { - if let Some(wl) = subst_list.as_mut() { - wl.push(ty); - } - } + Ok(ty) } TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), @@ -490,6 +533,7 @@ pub fn get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds( top_level_defs, unifier, + primitives, ty.as_ref(), subst_list, )?; @@ -499,7 +543,13 @@ pub fn get_type_from_type_annotation_kinds( let tys = tys .iter() .map(|x| { - get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) + get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + x, + subst_list, + ) }) .collect::, _>>()?; Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index c2a1d1947..78c7ff913 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -113,7 +113,9 @@ fn handle_typevar_definition( x, HashMap::new(), )?; - get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None) + get_type_from_type_annotation_kinds( + def_list, unifier, primitives, &ty, &mut None, + ) }) .collect::, _>>()?; let loc = func.location; @@ -152,7 +154,7 @@ fn handle_typevar_definition( HashMap::new(), )?; let constraint = - get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; + get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?; let loc = func.location; Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)