diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 6b15049d4..50ba54edb 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -11,8 +11,8 @@ inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", feat rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" } indoc = "1.0" ena = "0.14" +itertools = "0.10.1" [dev-dependencies] test-case = "1.2.0" -itertools = "0.10.1" diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index fceaaeea8..74440eefe 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -6,6 +6,7 @@ extern crate inkwell; extern crate rustpython_parser; extern crate indoc; extern crate ena; +extern crate itertools; mod typecheck; diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index bb470cd01..cfab64c16 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -4,3 +4,4 @@ mod magic_methods; pub mod symbol_resolver; mod test_typedef; pub mod typedef; +pub mod type_inferencer; diff --git a/nac3core/src/typecheck/symbol_resolver.rs b/nac3core/src/typecheck/symbol_resolver.rs index a6eff4409..accb1aaee 100644 --- a/nac3core/src/typecheck/symbol_resolver.rs +++ b/nac3core/src/typecheck/symbol_resolver.rs @@ -16,8 +16,8 @@ pub enum SymbolValue<'a> { } pub trait SymbolResolver { - fn get_symbol_type(&self, str: &str) -> Option; - fn get_symbol_value(&self, str: &str) -> Option; - fn get_symbol_location(&self, str: &str) -> Option; + fn get_symbol_type(&mut self, str: &str) -> Option; + fn get_symbol_value(&mut self, str: &str) -> Option; + fn get_symbol_location(&mut self, str: &str) -> Option; // handle function call etc. } diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs new file mode 100644 index 000000000..3ccc3016f --- /dev/null +++ b/nac3core/src/typecheck/type_inferencer.rs @@ -0,0 +1,227 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::convert::TryInto; +use std::iter::once; +use std::rc::Rc; + +use super::magic_methods::*; +use super::symbol_resolver::{SymbolResolver, SymbolType}; +use super::typedef::{Call, Type, TypeEnum, Unifier}; +use itertools::izip; +use rustpython_parser::ast::{self, fold::Fold}; + +pub struct PrimitiveStore { + int32: Type, + int64: Type, + float: Type, + bool: Type, + none: Type, +} + +pub struct Inferencer<'a> { + resolver: &'a mut Box, + unifier: &'a mut Unifier, + variable_mapping: &'a mut HashMap, + calls: &'a mut Vec>, + primitives: &'a PrimitiveStore, +} + +impl<'a> Fold<()> for Inferencer<'a> { + type TargetU = Option; + type Error = String; + + fn map_user(&mut self, _: ()) -> Result { + Ok(None) + } +} + +type InferenceResult = Result; + +impl<'a> Inferencer<'a> { + fn build_method_call( + &mut self, + method: String, + obj: Type, + params: Vec, + ret: Type, + ) -> InferenceResult { + let call = Rc::new(Call { + posargs: params, + kwargs: HashMap::new(), + ret, + fun: RefCell::new(None), + }); + self.calls.push(call.clone()); + let call = self.unifier.add_ty(TypeEnum::TCall { calls: vec![call] }); + let fields = once((method, call)).collect(); + let record = self.unifier.add_ty(TypeEnum::TRecord { fields }); + self.unifier.unify(obj, record)?; + Ok(ret) + } + + fn infer_identifier(&mut self, id: &str) -> InferenceResult { + if let Some(ty) = self.variable_mapping.get(id) { + Ok(*ty) + } else { + match self.resolver.get_symbol_type(id) { + Some(SymbolType::TypeName(_)) => { + Err("Expected expression instead of type".to_string()) + } + Some(SymbolType::Identifier(ty)) => Ok(ty), + None => { + let ty = self.unifier.get_fresh_var().0; + self.variable_mapping.insert(id.to_string(), ty); + Ok(ty) + } + } + } + } + + fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult { + match constant { + ast::Constant::Bool(_) => Ok(self.primitives.bool), + ast::Constant::Int(val) => { + let int32: Result = val.try_into(); + // int64 would be handled separately in functions + if int32.is_ok() { + Ok(self.primitives.int64) + } else { + Err("Integer out of bound".into()) + } + } + ast::Constant::Float(_) => Ok(self.primitives.float), + ast::Constant::Tuple(vals) => { + let ty: Result, _> = vals.iter().map(|x| self.infer_constant(x)).collect(); + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? })) + } + _ => Err("not supported".into()), + } + } + + fn infer_list(&mut self, elts: &[ast::Expr>]) -> InferenceResult { + let (ty, _) = self.unifier.get_fresh_var(); + for t in elts.iter() { + self.unifier.unify(ty, t.custom.unwrap())?; + } + Ok(ty) + } + + fn infer_tuple(&mut self, elts: &[ast::Expr>]) -> InferenceResult { + let ty = elts.iter().map(|x| x.custom.unwrap()).collect(); + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) + } + + fn infer_attribute(&mut self, value: &ast::Expr>, attr: &str) -> InferenceResult { + let (attr_ty, _) = self.unifier.get_fresh_var(); + let fields = once((attr.to_string(), attr_ty)).collect(); + let parent = self.unifier.add_ty(TypeEnum::TRecord { fields }); + self.unifier.unify(value.custom.unwrap(), parent)?; + Ok(attr_ty) + } + + fn infer_bool_ops(&mut self, values: &[ast::Expr>]) -> InferenceResult { + let b = self.primitives.bool; + for v in values { + self.unifier.unify(v.custom.unwrap(), b)?; + } + Ok(b) + } + + fn infer_bin_ops( + &mut self, + left: &ast::Expr>, + op: &ast::Operator, + right: &ast::Expr>, + ) -> InferenceResult { + let method = binop_name(op); + let ret = self.unifier.get_fresh_var().0; + self.build_method_call( + method.to_string(), + left.custom.unwrap(), + vec![right.custom.unwrap()], + ret, + ) + } + + fn infer_unary_ops( + &mut self, + op: &ast::Unaryop, + operand: &ast::Expr>, + ) -> InferenceResult { + let method = unaryop_name(op); + let ret = self.unifier.get_fresh_var().0; + self.build_method_call(method.to_string(), operand.custom.unwrap(), vec![], ret) + } + + fn infer_compare( + &mut self, + left: &ast::Expr>, + ops: &[ast::Cmpop], + comparators: &[ast::Expr>], + ) -> InferenceResult { + let boolean = self.primitives.bool; + for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { + let method = comparison_name(c) + .ok_or_else(|| "unsupported comparator".to_string())? + .to_string(); + self.build_method_call(method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?; + } + Ok(boolean) + } + + fn infer_subscript( + &mut self, + value: &ast::Expr>, + slice: &ast::Expr>, + ) -> InferenceResult { + let ty = self.unifier.get_fresh_var().0; + match &slice.node { + ast::ExprKind::Slice { lower, upper, step } => { + for v in [lower.as_ref(), upper.as_ref(), step.as_ref()] + .iter() + .flatten() + { + self.unifier + .unify(self.primitives.int32, v.custom.unwrap())?; + } + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.unifier.unify(value.custom.unwrap(), list)?; + Ok(list) + } + ast::ExprKind::Constant { + value: ast::Constant::Int(val), + .. + } => { + // the index is a constant, so value can be a sequence (either list/tuple) + let ind: i32 = val + .try_into() + .map_err(|_| "Index must be int32".to_string())?; + let map = once((ind, ty)).collect(); + let seq = self.unifier.add_ty(TypeEnum::TSeq { map }); + self.unifier.unify(value.custom.unwrap(), seq)?; + Ok(ty) + } + _ => { + // the index is not a constant, so value can only be a list + self.unifier + .unify(slice.custom.unwrap(), self.primitives.int32)?; + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.unifier.unify(value.custom.unwrap(), list)?; + Ok(ty) + } + } + } + + fn infer_if_expr( + &mut self, + test: &ast::Expr>, + body: ast::Expr>, + orelse: ast::Expr>, + ) -> InferenceResult { + self.unifier + .unify(test.custom.unwrap(), self.primitives.bool)?; + self.unifier + .unify(body.custom.unwrap(), orelse.custom.unwrap())?; + Ok(body.custom.unwrap()) + } +} diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index 9c20f9182..d2da32ea3 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -52,24 +52,24 @@ type VarMap = Mapping; #[derive(Clone)] pub struct Call { - posargs: Vec, - kwargs: HashMap, - ret: Type, - fun: RefCell>, + pub posargs: Vec, + pub kwargs: HashMap, + pub ret: Type, + pub fun: RefCell>, } #[derive(Clone)] pub struct FuncArg { - name: String, - ty: Type, - is_optional: bool, + pub name: String, + pub ty: Type, + pub is_optional: bool, } #[derive(Clone)] pub struct FunSignature { - args: Vec, - ret: Type, - params: VarMap, + pub args: Vec, + pub ret: Type, + pub params: VarMap, } // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.