diff --git a/nac3core/src/typecheck/context/inference_context.rs b/nac3core/src/typecheck/context/inference_context.rs index 93b978695..8cf29ea4c 100644 --- a/nac3core/src/typecheck/context/inference_context.rs +++ b/nac3core/src/typecheck/context/inference_context.rs @@ -6,12 +6,12 @@ use rustpython_parser::ast; use std::boxed::Box; use std::collections::HashMap; -struct ContextStack<'a> { +pub struct ContextStack { /// stack level, starts from 0 - level: u32, + pub level: u32, /// stack of symbol definitions containing (name, level) where `level` is the smallest level /// where the name is assigned a value - sym_def: Vec<(&'a str, u32)>, + pub sym_def: Vec<(String, u32)>, } pub struct InferenceContext<'a> { @@ -25,9 +25,9 @@ pub struct InferenceContext<'a> { /// identifier to (type, readable, location) mapping. /// an identifier might be defined earlier but has no value (for some code path), thus not /// readable. - sym_table: HashMap<&'a str, (Type, bool, Location)>, + pub sym_table: HashMap, /// stack - stack: ContextStack<'a>, + pub stack: ContextStack, } // non-trivial implementations here @@ -52,7 +52,7 @@ impl<'a> InferenceContext<'a> { /// execute the function with new scope. /// variable assignment would be limited within the scope (not readable outside), and type /// returns the list of variables assigned within the scope, and the result of the function - pub fn with_scope(&mut self, f: F) -> (Vec<(&'a str, Type, Location)>, R) + pub fn with_scope(&mut self, f: F) -> (Vec<(String, Type, Location)>, R) where F: FnOnce(&mut Self) -> R, { @@ -64,7 +64,7 @@ impl<'a> InferenceContext<'a> { let (_, level) = self.stack.sym_def.last().unwrap(); if *level > self.stack.level { let (name, _) = self.stack.sym_def.pop().unwrap(); - let (t, b, l) = self.sym_table.get_mut(name).unwrap(); + let (t, b, l) = self.sym_table.get_mut(&name).unwrap(); // set it to be unreadable *b = false; poped_names.push((name, t.clone(), *l)); @@ -77,8 +77,8 @@ impl<'a> InferenceContext<'a> { /// assign a type to an identifier. /// may return error if the identifier was defined but with different type - pub fn assign(&mut self, name: &'a str, ty: Type, loc: ast::Location) -> Result { - if let Some((t, x, _)) = self.sym_table.get_mut(name) { + pub fn assign(&mut self, name: String, ty: Type, loc: ast::Location) -> Result { + if let Some((t, x, _)) = self.sym_table.get_mut(&name) { if t == &ty { if !*x { self.stack.sym_def.push((name, self.stack.level)); @@ -89,7 +89,7 @@ impl<'a> InferenceContext<'a> { Err("different types".into()) } } else { - self.stack.sym_def.push((name, self.stack.level)); + self.stack.sym_def.push((name.clone(), self.stack.level)); self.sym_table.insert( name, (ty.clone(), true, Location::CodeRange(self.file, loc)), @@ -124,6 +124,11 @@ impl<'a> InferenceContext<'a> { self.resolver.get_symbol_location(name) } } + + /// check if an identifier is already defined + pub fn defined(&self, name: &String) -> bool { + self.sym_table.get(name).is_some() + } } // trivial getters: diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index bedf00db4..b3caeebfa 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -1,4 +1,5 @@ use std::convert::TryInto; +use std::fs::create_dir_all; use crate::typecheck::context::InferenceContext; use crate::typecheck::inference_core; @@ -6,11 +7,12 @@ use crate::typecheck::magic_methods; use crate::typecheck::typedef::{Type, TypeEnum}; use crate::typecheck::primitives; use rustpython_parser::ast; +use rustpython_parser::ast::fold::Fold; use super::inference_core::resolve_call; pub struct ExpressionTypeInferencer<'a> { - pub ctx: InferenceContext<'a> //FIXME: may need to remove this pub + pub ctx: InferenceContext<'a> } impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the function parameter for better error message? @@ -211,25 +213,6 @@ impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the functi } } - fn infer_slice(&self, lower: &Option>>>, upper: &Option>>>, step: &Option>>>) -> Result, String> { - let int32_type = self.ctx.get_primitive(primitives::INT32_TYPE); - let l = lower.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; - let u = upper.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; - let s = step.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; - - if l == &int32_type && u == &int32_type && s == &int32_type { - Ok(Some(self.ctx.get_primitive(primitives::SLICE_TYPE))) - } else { - Err("slice must be int32 type".into()) - } - } - fn infer_subscript(&self, value: &Box>>, slice: &Box>>) -> Result, String> { // let tt = value.custom.ok_or_else(|| "no value".to_string())?.as_ref(); @@ -239,13 +222,28 @@ impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the functi return Err("subscript is not supported for types other than list".into()); }; - if slice.custom == Some(self.ctx.get_primitive(primitives::SLICE_TYPE)) { - Ok(value.custom.clone()) - } else if slice.custom == Some(self.ctx.get_primitive(primitives::INT32_TYPE)) { - Ok(Some(t)) - } else { - Err("slice or index must be int32 type".into()) - } + if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { + let int32_type = self.ctx.get_primitive(primitives::INT32_TYPE); + let l = lower.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; + let u = upper.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; + let s = step.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; + + if l == &int32_type && u == &int32_type && s == &int32_type { + Ok(value.custom.clone()) + } else { + Err("slice must be int32 type".into()) + } + } else if slice.custom == Some(self.ctx.get_primitive(primitives::INT32_TYPE)) { + Ok(Some(t)) + } else { + Err("slice or index must be int32 type".into()) + } } fn infer_if_expr(&self, test: &Box>>, body: &Box>>, orelse: &Box>>) -> Result, String> { @@ -260,21 +258,56 @@ impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the functi } } - fn infer_simple_binding(&mut self, name: &'a ast::Expr>, ty: Type) -> Result<(), String> { + fn infer_list_comprehesion(&mut self, elt: &Box>>, generators: &Vec>>) -> Result, String> { + if generators[0] + .ifs + .iter() + .all(|x| x.custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) { + Ok(Some(TypeEnum::ParametricType( + primitives::LIST_TYPE, + vec![elt.custom.clone().ok_or_else(|| "elements should have value".to_string())?]).into())) + } else { + Err("test must be bool".into()) + } + } + + fn fold_comprehension_first(&mut self, node: ast::Comprehension>) -> Result>, String> { + Ok(ast::Comprehension { + target: node.target, + iter: Box::new(self.fold_expr(*node.iter)?), + ifs: node.ifs, + is_async: node.is_async + }) + } + + fn fold_comprehension_second(&mut self, node: ast::Comprehension>) -> Result>, String> { + Ok(ast::Comprehension { + target: Box::new(self.fold_expr(*node.target)?), + iter: node.iter, + ifs: node + .ifs + .into_iter() + .map(|x| self.fold_expr(x)) + .collect::, _>>()?, + is_async: node.is_async + }) + } + + fn infer_simple_binding(&mut self, name: &ast::Expr>, ty: Type) -> Result<(), String> { match &name.node { ast::ExprKind::Name {id, ctx: _} => { if id == "_" { Ok(()) - } else if self.ctx.defined(id.as_str()) { + } else if self.ctx.defined(id) { Err("duplicated naming".into()) } else { - self.ctx.assign(id.as_str(), ty, name.location)?; + self.ctx.assign(id.clone(), ty, name.location)?; Ok(()) } } - + ast::ExprKind::Tuple {elts, ctx: _} => { - if let TypeEnum::ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { + if let TypeEnum::ParametricType(primitives::TUPLE_TYPE, ls) = ty.as_ref() { if elts.len() == ls.len() { for (a, b) in elts.iter().zip(ls.iter()) { self.infer_simple_binding(a, b.clone())?; @@ -287,37 +320,12 @@ impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the functi Err("not supported".into()) } } - _ => Err("not supported".into()) } } - - fn infer_list_comprehesion(&mut self, elt: &Box>>, generators: &Vec>>) -> Result, String> { - if generators.len() != 1 { - Err("only 1 generator statement is supported".into()) - } else { - let gen = &generators[0]; - if gen.is_async { - Err("async is not supported".into()) - } else { - let iter_type = gen.iter.custom.as_ref().ok_or("no value".to_string())?.as_ref(); - if let TypeEnum::ParametricType(primitives::LIST_TYPE, ref ls) = iter_type { - self.ctx.with_scope(|x| { - // x.infer_simple_binding(&gen.target, ls[0].clone()); // FIXME: - Ok(None) - }).1 - } else { - Err("iteration is supported for list only".into()) - } - } - } - } - - - } -// REVIEW: field custom: from () to Option or just Option? + impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { type TargetU = Option; type Error = String; @@ -325,29 +333,55 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { fn map_user(&mut self, user: Option) -> Result { Ok(user) } - - // override the default fold_comprehension to avoid errors caused by folding locally bound variable - fn fold_comprehension(&mut self, node: ast::Comprehension>) -> Result, Self::Error> { - Ok(ast::Comprehension { - target: node.target, - iter: Box::new(self.fold_expr(*node.iter)?), - ifs: node.ifs, - is_async: node.is_async - }) - } - + fn fold_expr(&mut self, node: ast::Expr>) -> Result, Self::Error> { assert_eq!(node.custom, None); // NOTE: should pass let mut expr = node; - if let ast::Expr {location: _, custom: _, node: ast::ExprKind::ListComp {elt, generators } } = expr { - expr = ast::Expr { - location: expr.location, - custom: expr.custom, - node: ast::ExprKind::ListComp { - elt, - generators: generators.into_iter().map(|x| self.fold_comprehension(x)).collect::, _>>()? + + if let ast::Expr {location, custom, node: ast::ExprKind::ListComp {elt, generators } } = expr { + // is list comprehension, only fold generators which does not include unknown identifiers introduced by list comprehension + if generators.len() != 1 { + return Err("only 1 generator statement is supported".into()) + } + let generators_first_folded = generators + .into_iter() + .map(|x| self.fold_comprehension_first(x)).collect::, _>>()?; + + let gen = &generators_first_folded[0]; + let iter_type = gen.iter.custom.as_ref().ok_or("no value".to_string())?.as_ref(); + + if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = iter_type { + self.ctx.stack.level += 1; // FIXME: how to use with_scope?? + + self.infer_simple_binding(&gen.target, ls[0].clone())?; + expr = ast::Expr { + location, + custom, + node: ast::ExprKind::ListComp { + elt: Box::new(self.fold_expr(*elt)?), + generators: generators_first_folded + .into_iter() + .map(|x| self.fold_comprehension_second(x)) + .collect::, _>>()? + } + }; + + self.ctx.stack.level -= 1; + while !self.ctx.stack.sym_def.is_empty() { + let (_, level) = self.ctx.stack.sym_def.last().unwrap(); + if *level > self.ctx.stack.level { + let (name, _) = self.ctx.stack.sym_def.pop().unwrap(); + let (t, b, l) = self.ctx.sym_table.get_mut(&name).unwrap(); + // set it to be unreadable + *b = false; + } else { + break; + } } - }; + + } else { + return Err("iteration is supported for list only".into()); + } } else { // if not listcomp which requires special handling, skip current level, make sure child nodes have their type expr = ast::fold::fold_expr(self, expr)?; @@ -429,15 +463,15 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { node: expr.node }), - // REVIEW: add a new primitive type for slice and do type check of bounds here? + /* // REVIEW: add a new primitive type for slice and do type check of bounds here? ast::ExprKind::Slice {lower, upper, step } => Ok(ast::Expr { location: expr.location, custom: self.infer_slice(lower, upper, step)?, node: expr.node - }), + }), */ - ast::ExprKind::Subscript {value, slice, ctx} => + ast::ExprKind::Subscript {value, slice, ctx: _} => Ok(ast::Expr { location: expr.location, custom: self.infer_subscript(value, slice)?, @@ -451,15 +485,14 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { node: expr.node }), - ast::ExprKind::ListComp {elt, generators} => + ast::ExprKind::ListComp {elt, generators} => { + Ok(ast::Expr { location: expr.location, custom: self.infer_list_comprehesion(elt, generators)?, node: expr.node - }), - - - + }) + } _ => { // not supported Err("not supported yet".into())