diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 8efcd43..aa01c7d 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,6 +1,13 @@ -use crate::location::Location; -use crate::top_level::DefinitionId; -use crate::typecheck::typedef::Type; +use std::cell::RefCell; +use std::collections::HashMap; + +use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef}; +use crate::typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{Type, Unifier}, +}; +use crate::{location::Location, typecheck::typedef::TypeEnum}; +use itertools::{chain, izip}; use rustpython_parser::ast::Expr; #[derive(Clone, PartialEq)] @@ -15,11 +22,121 @@ pub enum SymbolValue { } pub trait SymbolResolver { - fn get_symbol_type(&self, str: &str) -> Option; - fn parse_type_name(&self, expr: &Expr<()>) -> Option; + fn get_symbol_type( + &self, + unifier: &mut Unifier, + primitives: &PrimitiveStore, + str: &str, + ) -> Option; fn get_identifier_def(&self, str: &str) -> DefinitionId; fn get_symbol_value(&self, str: &str) -> Option; fn get_symbol_location(&self, str: &str) -> Option; - fn get_module_resolver(&self, module_name: &str) -> Option<&dyn SymbolResolver>; // NOTE: for getting imported modules' symbol resolver? - // handle function call etc. + // handle function call etc. +} + +impl dyn SymbolResolver { + pub fn parse_type_annotation( + &self, + top_level: &TopLevelContext, + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &Expr, + ) -> Result { + use rustpython_parser::ast::ExprKind::*; + match &expr.node { + Name { id, .. } => match id.as_str() { + "int32" => Ok(primitives.int32), + "int64" => Ok(primitives.int64), + "float" => Ok(primitives.float), + "bool" => Ok(primitives.bool), + "None" => Ok(primitives.none), + x => { + let obj_id = self.get_identifier_def(x); + let defs = top_level.definitions.read(); + let def = defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); + } + let fields = RefCell::new( + chain( + fields.iter().map(|(k, v)| (k.clone(), *v)), + methods.iter().map(|(k, v, _)| (k.clone(), *v)), + ) + .collect(), + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) + } else { + Err("Cannot use function name as type".into()) + } + } + }, + Subscript { value, slice, .. } => { + if let Name { id, .. } = &value.node { + if id == "virtual" { + let ty = + self.parse_type_annotation(top_level, unifier, primitives, slice)?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } else { + let types = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + self.parse_type_annotation(top_level, unifier, primitives, v) + }) + .collect::, _>>()? + } else { + vec![self.parse_type_annotation(top_level, unifier, primitives, slice)?] + }; + + let obj_id = self.get_identifier_def(id); + let defs = top_level.definitions.read(); + let def = defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if types.len() != type_vars.len() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )); + } + let mut subst = HashMap::new(); + for (var, ty) in izip!(type_vars.iter(), types.iter()) { + let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { + *id + } else { + unreachable!() + }; + subst.insert(id, *ty); + } + let mut fields = fields + .iter() + .map(|(attr, ty)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (attr.clone(), ty) + }) + .collect::>(); + fields.extend(methods.iter().map(|(attr, ty, _)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (attr.clone(), ty) + })); + let fields = RefCell::new(fields); + Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) + } else { + Err("Cannot use function name as type".into()) + } + } + } else { + Err("unsupported type expression".into()) + } + } + _ => Err("unsupported type expression".into()), + } + } } diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index 9c60795..5497e90 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -440,37 +440,3 @@ impl TopLevelComposer { } } -pub fn parse_type_var( - input: &ast::Expr, - resolver: &dyn SymbolResolver, -) -> Result { - match &input.node { - ast::ExprKind::Name { id, .. } => resolver - .get_symbol_type(id) - .ok_or_else(|| "unknown type variable identifer".to_string()), - - ast::ExprKind::Attribute { value, attr, .. } => { - if let ast::ExprKind::Name { id, .. } = &value.node { - let next_resolver = resolver - .get_module_resolver(id) - .ok_or_else(|| "unknown imported module".to_string())?; - next_resolver - .get_symbol_type(attr) - .ok_or_else(|| "unknown type variable identifer".to_string()) - } else { - unimplemented!() - // recursively resolve attr thing, FIXME: new problem: how do we handle this? - // # A.py - // class A: - // T = TypeVar('T', int, bool) - // pass - // # B.py - // import A - // class B(Generic[A.A.T]): - // pass - } - } - - _ => Err("not supported".into()), - } -} diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d81c0d0..cdd83a5 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc}; use super::magic_methods::*; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; -use crate::symbol_resolver::SymbolResolver; +use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext}; use itertools::izip; use rustpython_parser::ast::{ self, @@ -44,6 +44,7 @@ pub struct FunctionData { } pub struct Inferencer<'a> { + pub top_level: &'a TopLevelContext, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, @@ -81,11 +82,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } else { None }; - let annotation_type = self - .function_data - .resolver - .parse_type_name(annotation.as_ref()) - .ok_or_else(|| "cannot parse type name".to_string())?; + let annotation_type = self.function_data.resolver.parse_type_annotation( + self.top_level, + self.unifier, + &self.primitives, + annotation.as_ref(), + )?; self.unifier.unify(annotation_type, target.custom.unwrap())?; let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); Located { @@ -235,6 +237,7 @@ impl<'a> Inferencer<'a> { primitives: self.primitives, virtual_checks: self.virtual_checks, calls: self.calls, + top_level: self.top_level, variable_mapping, }; let fun = FunSignature { @@ -275,6 +278,7 @@ impl<'a> Inferencer<'a> { function_data: self.function_data, unifier: self.unifier, virtual_checks: self.virtual_checks, + top_level: self.top_level, variable_mapping, primitives: self.primitives, calls: self.calls, @@ -336,10 +340,12 @@ impl<'a> Inferencer<'a> { } let arg0 = self.fold_expr(args.remove(0))?; let ty = if let Some(arg) = args.pop() { - self.function_data - .resolver - .parse_type_name(&arg) - .ok_or_else(|| "error parsing type".to_string())? + self.function_data.resolver.parse_type_annotation( + self.top_level, + self.unifier, + self.primitives, + &arg, + )? } else { self.unifier.get_fresh_var().0 }; @@ -412,11 +418,15 @@ impl<'a> Inferencer<'a> { if let Some(ty) = self.variable_mapping.get(id) { Ok(*ty) } else { - Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| { - let ty = self.unifier.get_fresh_var().0; - self.variable_mapping.insert(id.to_string(), ty); - ty - })) + Ok(self + .function_data + .resolver + .get_symbol_type(self.unifier, self.primitives, id) + .unwrap_or_else(|| { + let ty = self.unifier.get_fresh_var().0; + self.variable_mapping.insert(id.to_string(), ty); + ty + })) } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 44775c8..8fbd88f 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,30 +1,23 @@ use super::super::typedef::*; use super::*; -use crate::location::Location; use crate::symbol_resolver::*; use crate::top_level::DefinitionId; +use crate::{location::Location, top_level::TopLevelDef}; use indoc::indoc; use itertools::zip; -use rustpython_parser::ast; +use parking_lot::RwLock; use rustpython_parser::parser::parse_program; use test_case::test_case; struct Resolver { - identifier_mapping: HashMap, + id_to_type: HashMap, + id_to_def: HashMap, class_names: HashMap, } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, str: &str) -> Option { - self.identifier_mapping.get(str).cloned() - } - - fn parse_type_name(&self, ty: &ast::Expr<()>) -> Option { - if let ExprKind::Name { id, .. } = &ty.node { - self.class_names.get(id).cloned() - } else { - unimplemented!() - } + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { + self.id_to_type.get(str).cloned() } fn get_symbol_value(&self, _: &str) -> Option { @@ -35,12 +28,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, _: &str) -> DefinitionId { - unimplemented!() - } - - fn get_module_resolver(&self, _: &str) -> Option<&dyn SymbolResolver> { - unimplemented!() + fn get_identifier_def(&self, id: &str) -> DefinitionId { + self.id_to_def.get(id).cloned().unwrap() } } @@ -52,6 +41,7 @@ struct TestEnvironment { pub identifier_mapping: HashMap, pub virtual_checks: Vec<(Type, Type)>, pub calls: HashMap>, + pub top_level: TopLevelContext, } impl TestEnvironment { @@ -101,11 +91,17 @@ impl TestEnvironment { identifier_mapping.insert("None".into(), none); let resolver = Arc::new(Resolver { - identifier_mapping: identifier_mapping.clone(), + id_to_type: identifier_mapping.clone(), + id_to_def: Default::default(), class_names: Default::default(), }) as Arc; TestEnvironment { + top_level: TopLevelContext { + definitions: Default::default(), + unifiers: Default::default(), + conetexts: Default::default(), + }, unifier, function_data: FunctionData { resolver, @@ -123,6 +119,7 @@ impl TestEnvironment { fn new() -> TestEnvironment { let mut unifier = Unifier::new(); let mut identifier_mapping = HashMap::new(); + let mut top_level_defs = Vec::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), fields: HashMap::new().into(), @@ -149,6 +146,16 @@ impl TestEnvironment { params: HashMap::new(), }); identifier_mapping.insert("None".into(), none); + for i in 0..5 { + top_level_defs.push(RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(i), + type_vars: Default::default(), + fields: Default::default(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + })); + } let primitives = PrimitiveStore { int32, int64, float, bool, none }; @@ -159,6 +166,14 @@ impl TestEnvironment { fields: [("a".into(), v0)].iter().cloned().collect::>().into(), params: [(id, v0)].iter().cloned().collect(), }); + top_level_defs.push(RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(5), + type_vars: vec![v0], + fields: [("a".into(), v0)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + })); identifier_mapping.insert( "Foo".into(), @@ -183,6 +198,14 @@ impl TestEnvironment { .into(), params: Default::default(), }); + top_level_defs.push(RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(6), + type_vars: Default::default(), + fields: [("a".into(), int32), ("b".into(), fun)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + })); identifier_mapping.insert( "Bar".into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { @@ -201,6 +224,14 @@ impl TestEnvironment { .into(), params: Default::default(), }); + top_level_defs.push(RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(7), + type_vars: Default::default(), + fields: [("a".into(), bool), ("b".into(), fun)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + })); identifier_mapping.insert( "Bar2".into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { @@ -225,12 +256,28 @@ impl TestEnvironment { .cloned() .collect(); - let resolver = - Arc::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names }) - as Arc; + let top_level = TopLevelContext { + definitions: Arc::new(RwLock::new(top_level_defs)), + unifiers: Default::default(), + conetexts: Default::default(), + }; + + let resolver = Arc::new(Resolver { + id_to_type: identifier_mapping.clone(), + id_to_def: [ + ("Foo".into(), DefinitionId(5)), + ("Bar".into(), DefinitionId(6)), + ("Bar2".into(), DefinitionId(7)), + ] + .iter() + .cloned() + .collect(), + class_names, + }) as Arc; TestEnvironment { unifier, + top_level, function_data: FunctionData { resolver, bound_variables: Vec::new(), @@ -246,6 +293,7 @@ impl TestEnvironment { fn get_inferencer(&mut self) -> Inferencer { Inferencer { + top_level: &self.top_level, function_data: &mut self.function_data, unifier: &mut self.unifier, variable_mapping: Default::default(), diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 38e2a9f..5126d84 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -645,7 +645,7 @@ impl Unifier { /// If this returns Some(T), T would be the substituted type. /// If this returns None, the result type would be the original type /// (no substitution has to be done). - fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { + pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { use TypeVarMeta::*; let ty = self.unification_table.probe_value(a).clone(); // this function would only be called when we instantiate functions.