diff --git a/nac3core/src/typecheck/symbol_resolver.rs b/nac3core/src/typecheck/symbol_resolver.rs index 96003410..669f7632 100644 --- a/nac3core/src/typecheck/symbol_resolver.rs +++ b/nac3core/src/typecheck/symbol_resolver.rs @@ -2,11 +2,6 @@ use super::typedef::Type; use super::location::Location; use rustpython_parser::ast::Expr; -pub enum SymbolType { - TypeName(Type), - Identifier(Type), -} - pub enum SymbolValue<'a> { I32(i32), I64(i64), diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 728bb7cc..7cd72b41 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -14,20 +14,23 @@ use rustpython_parser::ast::{ Arguments, Comprehension, ExprKind, Located, Location, }; +#[cfg(test)] +mod test; + pub struct PrimitiveStore { - int32: Type, - int64: Type, - float: Type, - bool: Type, - none: Type, + pub int32: Type, + pub int64: Type, + pub float: Type, + pub bool: Type, + pub none: Type, } pub struct Inferencer<'a> { - resolver: &'a mut Box, - unifier: &'a mut Unifier, - variable_mapping: HashMap, - calls: &'a mut Vec>, - primitives: &'a PrimitiveStore, + pub resolver: &'a mut Box, + pub unifier: &'a mut Unifier, + pub variable_mapping: HashMap, + pub calls: &'a mut Vec>, + pub primitives: &'a PrimitiveStore, } struct NaiveFolder(); @@ -69,7 +72,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { .resolver .parse_type_name(annotation.as_ref()) .ok_or_else(|| "cannot parse type name".to_string())?; - self.unifier.unify(annotation_type, target.custom.unwrap())?; + self.unifier + .unify(annotation_type, target.custom.unwrap())?; let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); Located { location: node.location, @@ -102,7 +106,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } ast::StmtKind::AnnAssign { .. } => {} - _ => return Err("Unsupported statement type".to_string()) + _ => return Err("Unsupported statement type".to_string()), }; Ok(stmt) } @@ -358,7 +362,7 @@ impl<'a> Inferencer<'a> { if id == "int64" && args.len() == 1 { if let ExprKind::Constant { value: ast::Constant::Int(val), - .. + kind, } = &args[0].node { let int64: Result = val.try_into(); @@ -377,7 +381,14 @@ impl<'a> Inferencer<'a> { location: func.location, node: ExprKind::Name { id, ctx }, }), - args: vec![self.fold_expr(args.pop().unwrap())?], + args: vec![Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(val.clone()), + kind: kind.clone(), + }, + }], keywords: vec![], }, }); diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs new file mode 100644 index 00000000..23f99644 --- /dev/null +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -0,0 +1,163 @@ +use super::super::location::Location; +use super::super::symbol_resolver::*; +use super::super::typedef::*; +use super::*; +use indoc::indoc; +use rustpython_parser::ast; +use rustpython_parser::parser::parse_program; +use test_case::test_case; + +struct Resolver { + type_mapping: HashMap, +} + +impl SymbolResolver for Resolver { + fn get_symbol_type(&mut self, str: &str) -> Option { + self.type_mapping.get(str).cloned() + } + + fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option { + unimplemented!() + } + + fn get_symbol_value(&mut self, _: &str) -> Option { + unimplemented!() + } + + fn get_symbol_location(&mut self, _: &str) -> Option { + unimplemented!() + } +} + +struct TestEnvironment { + pub unifier: Unifier, + pub resolver: Box, + pub calls: Vec>, + pub primitives: PrimitiveStore, + pub id_to_name: HashMap, +} + +impl TestEnvironment { + fn new() -> TestEnvironment { + let mut unifier = Unifier::new(); + let mut type_mapping = HashMap::new(); + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: 0, + fields: HashMap::new(), + params: HashMap::new(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: 1, + fields: HashMap::new(), + params: HashMap::new(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: 2, + fields: HashMap::new(), + params: HashMap::new(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: 3, + fields: HashMap::new(), + params: HashMap::new(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: 4, + fields: HashMap::new(), + params: HashMap::new(), + }); + type_mapping.insert("int32".into(), int32); + type_mapping.insert("int64".into(), int64); + type_mapping.insert("float".into(), float); + type_mapping.insert("bool".into(), bool); + type_mapping.insert("none".into(), none); + + let primitives = PrimitiveStore { + int32, + int64, + float, + bool, + none, + }; + + let (v0, id) = unifier.get_fresh_var(); + type_mapping.insert( + "foo".into(), + unifier.add_ty(TypeEnum::TObj { + obj_id: 5, + fields: [("a".into(), v0)].iter().cloned().collect(), + params: [(id, v0)].iter().cloned().collect(), + }), + ); + + let id_to_name = [ + (0, "int32".to_string()), + (1, "int64".to_string()), + (2, "float".to_string()), + (3, "bool".to_string()), + (4, "none".to_string()), + (5, "Foo".to_string()), + ] + .iter() + .cloned() + .collect(); + + let resolver = Box::new(Resolver { type_mapping }) as Box; + + TestEnvironment { + unifier, + resolver, + primitives, + id_to_name, + calls: Vec::new(), + } + } + + fn get_inferencer(&mut self) -> Inferencer { + Inferencer { + resolver: &mut self.resolver, + unifier: &mut self.unifier, + variable_mapping: Default::default(), + calls: &mut self.calls, + primitives: &mut self.primitives, + } + } +} + +#[test_case(indoc! {" + a = 1234 + b = int64(2147483648) + c = 1.234 + d = True + "}, + [("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect() + ; "primitives test")] +#[test_case(indoc! {" + a = lambda x, y: x + b = lambda x: a(x, x) + c = 1.234 + d = b(c) + "}, + [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect() + ; "lambda test")] +fn test_basic(source: &str, mapping: HashMap<&str, &str>) { + let mut env = TestEnvironment::new(); + let id_to_name = std::mem::take(&mut env.id_to_name); + let mut inferencer = env.get_inferencer(); + let statements = parse_program(source).unwrap(); + statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + for (k, v) in mapping.iter() { + let ty = inferencer.variable_mapping.get(*k).unwrap(); + let name = inferencer.unifier.stringify( + *ty, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); + } +} + diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index ef034f1e..4d00c11e 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -8,7 +8,7 @@ use std::ops::Deref; use std::rc::Rc; #[cfg(test)] -mod test_typedef; +mod test; #[derive(Copy, Clone, PartialEq, Eq, Debug)] /// Handle for a type, implementated as a key in the unification table. @@ -217,10 +217,14 @@ impl Unifier { } TypeEnum::TObj { obj_id, params, .. } => { let name = obj_to_name(*obj_id); - let mut params = params - .values() - .map(|v| self.stringify(*v, obj_to_name, var_to_name)); - format!("{}[{}]", name, params.join(", ")) + if params.len() > 0 { + let mut params = params + .values() + .map(|v| self.stringify(*v, obj_to_name, var_to_name)); + format!("{}[{}]", name, params.join(", ")) + } else { + name + } } TypeEnum::TCall { .. } => "call".to_owned(), TypeEnum::TFunc(signature) => { @@ -432,6 +436,9 @@ impl Unifier { return Err(format!("Unknown keyword argument {}", k)); } } + if !required.is_empty() { + return Err("Expected more arguments".to_string()); + } self.unify(*ret, signature.ret)?; *fun.borrow_mut() = Some(instantiated); } diff --git a/nac3core/src/typecheck/typedef/test_typedef.rs b/nac3core/src/typecheck/typedef/test.rs similarity index 100% rename from nac3core/src/typecheck/typedef/test_typedef.rs rename to nac3core/src/typecheck/typedef/test.rs