diff --git a/nac3core/src/typecheck/context/inference_context.rs b/nac3core/src/typecheck/context/inference_context.rs index c738eb930..c46dda374 100644 --- a/nac3core/src/typecheck/context/inference_context.rs +++ b/nac3core/src/typecheck/context/inference_context.rs @@ -56,8 +56,17 @@ impl<'a> InferenceContext<'a> { where F: FnOnce(&mut Self) -> R, { - self.stack.level += 1; + self.start_scope(); let result = f(self); + let poped_names = self.end_scope(); + (poped_names, result) + } + + pub fn start_scope(&mut self) { + self.stack.level += 1; + } + + pub fn end_scope(&mut self) -> Vec<(String, Type, Location)> { self.stack.level -= 1; let mut poped_names = Vec::new(); while !self.stack.sym_def.is_empty() { @@ -72,7 +81,7 @@ impl<'a> InferenceContext<'a> { break; } } - (poped_names, result) + poped_names } /// assign a type to an identifier. diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index 65ffac51b..73d9bc4c3 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -8,7 +8,12 @@ use crate::typecheck::primitives; use rustpython_parser::ast; use rustpython_parser::ast::fold::Fold; -impl<'a> ast::fold::Fold<()> for InferenceContext<'a> { +pub struct TypeInferencer<'a> { + pub ctx: InferenceContext<'a>, + pub error_stack: Vec<(String, ast::Location)> +} + +impl<'a> ast::fold::Fold<()> for TypeInferencer<'a> { type TargetU = Option; type Error = String; @@ -17,8 +22,8 @@ impl<'a> ast::fold::Fold<()> for InferenceContext<'a> { } fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { - // assert_eq!(node.custom, None); - + self.error_stack.push((node.node.name().into(), node.location)); + let expr = match &node.node { ast::ExprKind::ListComp { .. } => return self.fold_listcomp(node), _ => rustpython_parser::ast::fold::fold_expr(self, node)? @@ -28,7 +33,7 @@ impl<'a> ast::fold::Fold<()> for InferenceContext<'a> { // compute type info and store in the custom field custom: match &expr.node { ast::ExprKind::Constant {value, kind: _} => self.infer_constant(value), - ast::ExprKind::Name {id, ctx: _} => Ok(Some(self.resolve(id)?)), + ast::ExprKind::Name {id, ctx: _} => Ok(Some(self.ctx.resolve(id)?)), ast::ExprKind::List {elts, ctx: _} => self.infer_list(elts), ast::ExprKind::Tuple {elts, ctx: _} => self.infer_tuple(elts), ast::ExprKind::Attribute {value, attr, ctx: _} => self.infer_attribute(value, attr), @@ -49,27 +54,27 @@ impl<'a> ast::fold::Fold<()> for InferenceContext<'a> { } } -impl<'a> InferenceContext<'a> { +impl<'a> TypeInferencer<'a> { fn infer_constant(&self, constant: &ast::Constant) -> Result, String> { match constant { ast::Constant::Bool(_) => - Ok(Some(self.get_primitive(primitives::BOOL_TYPE))), + Ok(Some(self.ctx.get_primitive(primitives::BOOL_TYPE))), ast::Constant::Int(val) => { let int32: Result = val.try_into(); let int64: Result = val.try_into(); if int32.is_ok() { - Ok(Some(self.get_primitive(primitives::INT32_TYPE))) + Ok(Some(self.ctx.get_primitive(primitives::INT32_TYPE))) } else if int64.is_ok() { - Ok(Some(self.get_primitive(primitives::INT64_TYPE))) + Ok(Some(self.ctx.get_primitive(primitives::INT64_TYPE))) } else { Err("Integer out of bound".into()) } }, ast::Constant::Float(_) => - Ok(Some(self.get_primitive(primitives::FLOAT_TYPE))), + Ok(Some(self.ctx.get_primitive(primitives::FLOAT_TYPE))), ast::Constant::Tuple(vals) => { let result = vals @@ -134,16 +139,16 @@ impl<'a> InferenceContext<'a> { fn infer_attribute(&self, value: &ast::Expr>, attr: &str) -> Result, String> { let ty = value.custom.clone().ok_or_else(|| "no value".to_string())?; if let TypeEnum::TypeVariable(id) = ty.as_ref() { - let v = self.get_variable_def(*id); + let v = self.ctx.get_variable_def(*id); if v.bound.is_empty() { return Err("no fields on unbounded type variable".into()); } - let ty = v.bound[0].get_base(&self).and_then(|v| v.fields.get(attr)); + let ty = v.bound[0].get_base(&self.ctx).and_then(|v| v.fields.get(attr)); if ty.is_none() { return Err("unknown field".into()); } for x in v.bound[1..].iter() { - let ty1 = x.get_base(&self).and_then(|v| v.fields.get(attr)); + let ty1 = x.get_base(&self.ctx).and_then(|v| v.fields.get(attr)); if ty1 != ty { return Err("unknown field (type mismatch between variants)".into()); } @@ -151,7 +156,7 @@ impl<'a> InferenceContext<'a> { return Ok(Some(ty.unwrap().clone())); } - match ty.get_base(&self) { + match ty.get_base(&self.ctx) { Some(b) => match b.fields.get(attr) { Some(t) => Ok(Some(t.clone())), None => Err("no such field".into()), @@ -164,7 +169,7 @@ impl<'a> InferenceContext<'a> { assert_eq!(values.len(), 2); let left = values[0].custom.clone().ok_or_else(|| "no value".to_string())?; let right = values[1].custom.clone().ok_or_else(|| "no value".to_string())?; - let b = self.get_primitive(primitives::BOOL_TYPE); + let b = self.ctx.get_primitive(primitives::BOOL_TYPE); if left == b && right == b { Ok(Some(b)) } else { @@ -174,7 +179,7 @@ impl<'a> InferenceContext<'a> { fn infer_bin_ops(&self, left: &ast::Expr>, op: &ast::Operator, right: &ast::Expr>) -> Result, String> { inference_core::resolve_call( - &self, + &self.ctx, Some(left.custom.clone().ok_or_else(|| "no value".to_string())?), magic_methods::binop_name(op), &[right.custom.clone().ok_or_else(|| "no value".to_string())?]) @@ -182,13 +187,13 @@ impl<'a> InferenceContext<'a> { fn infer_unary_ops(&self, op: &ast::Unaryop, operand: &ast::Expr>) -> Result, String> { if let ast::Unaryop::Not = op { - if operand.custom == Some(self.get_primitive(primitives::BOOL_TYPE)) { - Ok(Some(self.get_primitive(primitives::BOOL_TYPE))) + if operand.custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE)) { + Ok(Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) } else { Err("logical not must be applied to bool".into()) } } else { - inference_core::resolve_call(&self, operand.custom.clone(), magic_methods::unaryop_name(op), &[]) + inference_core::resolve_call(&self.ctx, operand.custom.clone(), magic_methods::unaryop_name(op), &[]) } } @@ -196,9 +201,9 @@ impl<'a> InferenceContext<'a> { if left.custom.is_none() || (!comparators.iter().all(|x| x.custom.is_some())) { Err("comparison operands must have type".into()) } else { - let bool_type = Some(self.get_primitive(primitives::BOOL_TYPE)); + let bool_type = Some(self.ctx.get_primitive(primitives::BOOL_TYPE)); let ty_first = inference_core::resolve_call( - &self, + &self.ctx, Some(left.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?), magic_methods::comparison_name(&ops[0]).ok_or_else(|| "unsupported comparison".to_string())?, &[comparators[0].custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?])?; @@ -212,7 +217,7 @@ impl<'a> InferenceContext<'a> { .zip(comparators[1..].iter()) .zip(ops[1..].iter()) { let ty = inference_core::resolve_call( - &self, + &self.ctx, Some(a.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()), magic_methods::comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?, &[b.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()])?; @@ -229,14 +234,14 @@ impl<'a> InferenceContext<'a> { match &func.node { ast::ExprKind::Name {id, ctx: _} => inference_core::resolve_call( - &self, + &self.ctx, None, id, &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), ast::ExprKind::Attribute {value, attr, ctx: _} => inference_core::resolve_call( - &self, + &self.ctx, Some(value.custom.clone().ok_or_else(|| "no value".to_string())?), &attr, &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), @@ -252,7 +257,7 @@ impl<'a> InferenceContext<'a> { let val_type = value.custom.as_ref().ok_or_else(|| "no value".to_string())?.as_ref(); if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = val_type { if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { - let int32_type = self.get_primitive(primitives::INT32_TYPE); + let int32_type = self.ctx.get_primitive(primitives::INT32_TYPE); let l = lower.as_ref().map_or( Ok(&int32_type), |x| x.custom.as_ref().ok_or_else(|| "lower bound cannot be typped".to_string()))?; @@ -268,7 +273,7 @@ impl<'a> InferenceContext<'a> { } else { Err("slice must be int32 type".into()) } - } else if slice.custom == Some(self.get_primitive(primitives::INT32_TYPE)) { + } else if slice.custom == Some(self.ctx.get_primitive(primitives::INT32_TYPE)) { Ok(Some(ls[0].clone())) } else { Err("slice or index must be int32 type".into()) @@ -290,7 +295,7 @@ impl<'a> InferenceContext<'a> { } fn infer_if_expr(&self, test: &ast::Expr>, body: &ast::Expr>, orelse: &ast::Expr>) -> Result, String> { - if test.custom != Some(self.get_primitive(primitives::BOOL_TYPE)) { + if test.custom != Some(self.ctx.get_primitive(primitives::BOOL_TYPE)) { Err("test should be bool".into()) } else if body.custom == orelse.custom { Ok(body.custom.clone()) @@ -303,7 +308,7 @@ impl<'a> InferenceContext<'a> { if generators[0] .ifs .iter() - .all(|x| x.custom == Some(self.get_primitive(primitives::BOOL_TYPE))) { + .all(|x| x.custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) { Ok(Some(TypeEnum::ParametricType( primitives::LIST_TYPE, vec![elt.custom.clone().ok_or_else(|| "elements should have value".to_string())?]).into())) @@ -343,40 +348,43 @@ impl<'a> InferenceContext<'a> { .ok_or_else(|| "no value".to_string())? .as_ref() .clone() { - self.with_scope(|ctx| -> Result>, String> { - ctx.infer_simple_binding(&target, ls[0].clone())?; - let elt_folded = Box::new(ctx.fold_expr(*elt)?); - let target_folded = Box::new(ctx.fold_expr(*target)?); - let ifs_folded = ifs - .into_iter() - .map(|x| ctx.fold_expr(x)) - .collect::>>, _>>()?; - + + self.ctx.start_scope(); + self.infer_simple_binding(&target, ls[0].clone())?; + let elt_folded = Box::new(self.fold_expr(*elt)?); + let target_folded = Box::new(self.fold_expr(*target)?); + let ifs_folded = ifs + .into_iter() + .map(|x| self.fold_expr(x)) + .collect::>>, _>>()?; + + let result = if ifs_folded - .iter() - .all(|x| x.custom == Some(ctx.get_primitive(primitives::BOOL_TYPE))) { - Ok(ast::Expr { - location, - custom: Some(TypeEnum::ParametricType( - primitives::LIST_TYPE, - vec![elt_folded - .custom - .clone() - .ok_or_else(|| "elements cannot be typped".to_string())?]).into()), - node: ast::ExprKind::ListComp { - elt: elt_folded, - generators: vec![ast::Comprehension { - target: target_folded, - ifs: ifs_folded, - iter: iter_folded, - is_async - }] - } - }) - } else { - Err("test must be bool".into()) - } - }).1 + .iter() + .all(|x| x.custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) { + Ok(ast::Expr { + location, + custom: Some(TypeEnum::ParametricType( + primitives::LIST_TYPE, + vec![elt_folded + .custom + .clone() + .ok_or_else(|| "elements cannot be typped".to_string())?]).into()), + node: ast::ExprKind::ListComp { + elt: elt_folded, + generators: vec![ast::Comprehension { + target: target_folded, + ifs: ifs_folded, + iter: iter_folded, + is_async + }] + } + }) + } else { + Err("test must be bool".into()) + }; + self.ctx.end_scope(); + result } else { Err("iteration is supported for list only".into()) } @@ -390,10 +398,10 @@ impl<'a> InferenceContext<'a> { ast::ExprKind::Name {id, ctx: _} => { if id == "_" { Ok(()) - } else if self.defined(id) { + } else if self.ctx.defined(id) { Err("duplicated naming".into()) } else { - self.assign(id.clone(), ty, name.location)?; + self.ctx.assign(id.clone(), ty, name.location)?; Ok(()) } } @@ -415,14 +423,13 @@ impl<'a> InferenceContext<'a> { _ => Err("not supported".into()) } } -} -pub struct ExpressionInferencer<'a> { - pub ctx: InferenceContext<'a> -} -impl<'a> ExpressionInferencer<'a> { - pub fn fold_expr(&mut self, expr: ast::Expr) -> Result>, String> { - self.ctx.fold_expr(expr) + fn fold_expr(&mut self, node: ast::Expr<()>) -> Result>, String> { + let result = >::fold_expr(self, node); + if result.is_err() { + println!("{:?}", self.error_stack.pop().unwrap()); + } + result } } @@ -432,7 +439,7 @@ pub mod test { use rustpython_parser::ast::Expr; use super::*; - pub fn new_ctx<'a>() -> ExpressionInferencer<'a> { + pub fn new_ctx<'a>() -> TypeInferencer<'a> { struct S; impl SymbolResolver for S { fn get_symbol_location(&self, _str: &str) -> Option { None } @@ -440,7 +447,10 @@ pub mod test { fn get_symbol_value(&self, _str: &str) -> Option { None } } - ExpressionInferencer {ctx: InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3))} + TypeInferencer { + ctx: InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3)), + error_stack: Vec::new() + } } #[test] @@ -627,6 +637,8 @@ pub mod test { let ast10 = rustpython_parser::parser::parse_expression("4 if False and True else 8").unwrap(); let ast11 = rustpython_parser::parser::parse_expression("(1, 2, 3, 4)[1]").unwrap(); let ast12 = rustpython_parser::parser::parse_expression("(1, True, 3, False)[1]").unwrap(); + + let ast13 = rustpython_parser::parser::parse_expression("[1, True, 2]").unwrap(); let folded = inf.fold_expr(ast1).unwrap(); let folded_2 = inf.fold_expr(ast2).unwrap(); @@ -640,9 +652,11 @@ pub mod test { let folded_10 = inf.fold_expr(ast10).unwrap(); let folded_11 = inf.fold_expr(ast11).unwrap(); let folded_12 = inf.fold_expr(ast12).unwrap(); + println!("{:?}", folded.custom); println!("{:?}", folded_2.custom); + let folded_13 = inf.fold_expr(ast13); println!("{:?}", folded_3.custom); println!("{:?}", folded_4.custom); println!("{:?}", folded_5.custom);