From d46a4b2d38729ea3a4deb43a8862eab7acb67bdc Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 12 Aug 2021 10:25:32 +0800 Subject: [PATCH] symbol_resolver: fixed type variable handling --- nac3core/src/symbol_resolver.rs | 63 ++++++++++++------- .../src/typecheck/type_inferencer/test.rs | 4 +- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index aa01c7dcf..b6e3b7929 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -22,19 +22,22 @@ pub enum SymbolValue { } pub trait SymbolResolver { + // get type of type variable identifier or top-level function type fn get_symbol_type( &self, unifier: &mut Unifier, primitives: &PrimitiveStore, str: &str, ) -> Option; - fn get_identifier_def(&self, str: &str) -> DefinitionId; + // get the top-level definition of identifiers + fn get_identifier_def(&self, str: &str) -> Option; fn get_symbol_value(&self, str: &str) -> Option; fn get_symbol_location(&self, str: &str) -> Option; // handle function call etc. } impl dyn SymbolResolver { + // convert type annotation into type pub fn parse_type_annotation( &self, top_level: &TopLevelContext, @@ -52,29 +55,41 @@ impl dyn SymbolResolver { "None" => Ok(primitives.none), x => { let obj_id = self.get_identifier_def(x); - let defs = top_level.definitions.read(); - let def = defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if !type_vars.is_empty() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got 0", - type_vars.len() - )); + if let Some(obj_id) = obj_id { + let defs = top_level.definitions.read(); + let def = defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); + } + let fields = RefCell::new( + chain( + fields.iter().map(|(k, v)| (k.clone(), *v)), + methods.iter().map(|(k, v, _)| (k.clone(), *v)), + ) + .collect(), + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) + } else { + Err("Cannot use function name as type".into()) } - let fields = RefCell::new( - chain( - fields.iter().map(|(k, v)| (k.clone(), *v)), - methods.iter().map(|(k, v, _)| (k.clone(), *v)), - ) - .collect(), - ); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: Default::default(), - })) } else { - Err("Cannot use function name as type".into()) + // it could be a type variable + let ty = self + .get_symbol_type(unifier, primitives, x) + .ok_or_else(|| "Cannot use function name as type".to_owned())?; + if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { + Ok(ty) + } else { + Err(format!("Unknown type annotation {}", x)) + } } } }, @@ -95,7 +110,9 @@ impl dyn SymbolResolver { vec![self.parse_type_annotation(top_level, unifier, primitives, slice)?] }; - let obj_id = self.get_identifier_def(id); + let obj_id = self + .get_identifier_def(id) + .ok_or_else(|| format!("Unknown type annotation {}", id))?; let defs = top_level.definitions.read(); let def = defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 8fbd88f90..ee65dbe3e 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -28,8 +28,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, id: &str) -> DefinitionId { - self.id_to_def.get(id).cloned().unwrap() + fn get_identifier_def(&self, id: &str) -> Option { + self.id_to_def.get(id).cloned() } }