From 4abe99f6b30ce17758e21f22bff063050009c6ca Mon Sep 17 00:00:00 2001 From: CrescentonC Date: Wed, 14 Jul 2021 17:06:00 +0800 Subject: [PATCH] refactor the using of rustpython fold again, now can use with_scope, need further testing --- .../typecheck/context/inference_context.rs | 8 +- .../src/typecheck/expression_inference.rs | 775 +++++++++--------- 2 files changed, 380 insertions(+), 403 deletions(-) diff --git a/nac3core/src/typecheck/context/inference_context.rs b/nac3core/src/typecheck/context/inference_context.rs index 8cf29ea4..c738eb93 100644 --- a/nac3core/src/typecheck/context/inference_context.rs +++ b/nac3core/src/typecheck/context/inference_context.rs @@ -8,10 +8,10 @@ use std::collections::HashMap; pub struct ContextStack { /// stack level, starts from 0 - pub level: u32, + level: u32, /// stack of symbol definitions containing (name, level) where `level` is the smallest level /// where the name is assigned a value - pub sym_def: Vec<(String, u32)>, + sym_def: Vec<(String, u32)>, } pub struct InferenceContext<'a> { @@ -25,9 +25,9 @@ pub struct InferenceContext<'a> { /// identifier to (type, readable, location) mapping. /// an identifier might be defined earlier but has no value (for some code path), thus not /// readable. - pub sym_table: HashMap, + sym_table: HashMap, /// stack - pub stack: ContextStack, + stack: ContextStack, } // non-trivial implementations here diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index b3caeebf..ac055a56 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -1,5 +1,4 @@ use std::convert::TryInto; -use std::fs::create_dir_all; use crate::typecheck::context::InferenceContext; use crate::typecheck::inference_core; @@ -9,411 +8,49 @@ use crate::typecheck::primitives; use rustpython_parser::ast; use rustpython_parser::ast::fold::Fold; -use super::inference_core::resolve_call; - -pub struct ExpressionTypeInferencer<'a> { - pub ctx: InferenceContext<'a> -} - -impl<'a> ExpressionTypeInferencer<'a> { // NOTE: add location here in the function parameter for better error message? - - fn infer_constant_val(&self, constant: &ast::Constant) -> Result, String> { - match constant { - ast::Constant::Bool(_) => - Ok(Some(self.ctx.get_primitive(primitives::BOOL_TYPE))), - - ast::Constant::Int(val) => { - let int32: Result = val.try_into(); - let int64: Result = val.try_into(); - - if int32.is_ok() { - Ok(Some(self.ctx.get_primitive(primitives::INT32_TYPE))) - } else if int64.is_ok() { - Ok(Some(self.ctx.get_primitive(primitives::INT64_TYPE))) - } else { - Err("Integer out of bound".into()) - } - }, - - ast::Constant::Float(_) => - Ok(Some(self.ctx.get_primitive(primitives::FLOAT_TYPE))), - - ast::Constant::Tuple(vals) => { - let result = vals - .into_iter() - .map(|x| self.infer_constant_val(x)) - .collect::>(); - - if result.iter().all(|x| x.is_ok()) { - Ok(Some(TypeEnum::ParametricType( - primitives::TUPLE_TYPE, - result - .into_iter() - .map(|x| x.unwrap().unwrap()) - .collect::>(), - ).into())) - } else { - Err("Some elements in tuple cannot be typed".into()) - } - } - - _ => Err("not supported".into()) - } - } - - - fn infer_list_val(&self, elts: &Vec>>) -> Result, String> { - if elts.is_empty() { - Ok(Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![TypeEnum::BotType.into()]).into())) - } else { - let types = elts - .iter() - .map(|x| &x.custom) - .collect::>(); - - if types.iter().all(|x| x.is_some()) { - let head = types.iter().next().unwrap(); // here unwrap alone should be fine after the previous check - if types.iter().all(|x| x.eq(head)) { - Ok(Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![(*head).clone().unwrap()]).into())) - } else { - Err("inhomogeneous list is not allowed".into()) - } - } else { - Err("list elements must have some type".into()) - } - } - } - - fn infer_tuple_val(&self, elts: &Vec>>) -> Result, String> { - let types = elts - .iter() - .map(|x| (x.custom).clone()) - .collect::>(); - - if types.iter().all(|x| x.is_some()) { - Ok(Some(TypeEnum::ParametricType( - primitives::TUPLE_TYPE, - types.into_iter().map(|x| x.unwrap()).collect()).into())) // unwrap alone should be fine after the previous check - } else { - Err("tuple elements must have some type".into()) - } - } - - fn infer_arrtibute(&self, value: &Box>>, attr: &str) -> Result, String> { - let ty = value.custom.clone().ok_or_else(|| "no value".to_string())?; - if let TypeEnum::TypeVariable(id) = ty.as_ref() { - let v = self.ctx.get_variable_def(*id); - if v.bound.is_empty() { - return Err("no fields on unbounded type variable".into()); - } - let ty = v.bound[0].get_base(&self.ctx).and_then(|v| v.fields.get(attr)); - if ty.is_none() { - return Err("unknown field".into()); - } - for x in v.bound[1..].iter() { - let ty1 = x.get_base(&self.ctx).and_then(|v| v.fields.get(attr)); - if ty1 != ty { - return Err("unknown field (type mismatch between variants)".into()); - } - } - return Ok(Some(ty.unwrap().clone())); - } - - match ty.get_base(&self.ctx) { - Some(b) => match b.fields.get(attr) { - Some(t) => Ok(Some(t.clone())), - None => Err("no such field".into()), - }, - None => Err("this object has no fields".into()), - } - } - - fn infer_bool_ops(&self, values: &Vec>>) -> Result, String> { - assert_eq!(values.len(), 2); - let left = values[0].custom.clone().ok_or_else(|| "no value".to_string())?; - let right = values[1].custom.clone().ok_or_else(|| "no value".to_string())?; - let b = self.ctx.get_primitive(primitives::BOOL_TYPE); - if left == b && right == b { - Ok(Some(b)) - } else { - Err("bool operands must be bool".to_string()) - } - } - - fn _infer_bin_ops(&self, _left: &Box>>, _op: &ast::Operator, _right: &Box>>) -> Result, String> { - Err("no need this function".into()) - } - - fn infer_unary_ops(&self, op: &ast::Unaryop, operand: &Box>>) -> Result, String> { - if let ast::Unaryop::Not = op { - if (**operand).custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE)) { - Ok(Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) - } else { - Err("logical not must be applied to bool".into()) - } - } else { - inference_core::resolve_call(&self.ctx, (**operand).custom.clone(), magic_methods::unaryop_name(op), &[]) - } - } - - fn infer_compare(&self, left: &Box>>, ops: &Vec, comparators: &Vec>>) -> Result, String> { - assert!(comparators.len() > 0); - if left.custom.is_none() || (!comparators.iter().all(|x| x.custom.is_some())) { - Err("comparison operands must have type".into()) - } else { - let bool_type = Some(self.ctx.get_primitive(primitives::BOOL_TYPE)); - let ty_first = resolve_call( - &self.ctx, - Some(left.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()), - magic_methods::comparison_name(&ops[0]).ok_or_else(|| "unsupported comparison".to_string())?, - &[comparators[0].custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?])?; - if ty_first != bool_type { - return Err("comparison result must be boolean".into()); - } - - for ((a, b), op) - in comparators[..(comparators.len() - 1)] - .iter() - .zip(comparators[1..].iter()) - .zip(ops[1..].iter()) { - let ty = resolve_call( - &self.ctx, - Some(a.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()), - magic_methods::comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?, - &[b.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()])?; - if ty != bool_type { - return Err("comparison result must be boolean".into()); - } - } - Ok(bool_type) - } - } - - fn infer_call(&self, func: &Box>>, args: &Vec>>, _keywords: &Vec>>) -> Result, String> { - if args.iter().all(|x| x.custom.is_some()) { - match &func.node { - ast::ExprKind::Name {id, ctx: _} - => resolve_call( - &self.ctx, - None, - id, - &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), - - ast::ExprKind::Attribute {value, attr, ctx: _} - => resolve_call( - &self.ctx, - Some(value.custom.clone().ok_or_else(|| "no value".to_string())?), - &attr, - &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), - - _ => Err("not supported".into()) - } - } else { - Err("function params must have type".into()) - } - } - - fn infer_subscript(&self, value: &Box>>, slice: &Box>>) -> Result, String> { - // let tt = value.custom.ok_or_else(|| "no value".to_string())?.as_ref(); - - let t = if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = value.custom.as_ref().ok_or_else(|| "no value".to_string())?.as_ref() { - ls[0].clone() - } else { - return Err("subscript is not supported for types other than list".into()); - }; - - if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { - let int32_type = self.ctx.get_primitive(primitives::INT32_TYPE); - let l = lower.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; - let u = upper.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; - let s = step.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; - - if l == &int32_type && u == &int32_type && s == &int32_type { - Ok(value.custom.clone()) - } else { - Err("slice must be int32 type".into()) - } - } else if slice.custom == Some(self.ctx.get_primitive(primitives::INT32_TYPE)) { - Ok(Some(t)) - } else { - Err("slice or index must be int32 type".into()) - } - } - - fn infer_if_expr(&self, test: &Box>>, body: &Box>>, orelse: &Box>>) -> Result, String> { - if test.custom != Some(self.ctx.get_primitive(primitives::BOOL_TYPE)) { - Err("test should be bool".into()) - } else { - if body.custom == orelse.custom { - Ok(body.custom.clone()) - } else { - Err("divergent type at if expression".into()) - } - } - } - - fn infer_list_comprehesion(&mut self, elt: &Box>>, generators: &Vec>>) -> Result, String> { - if generators[0] - .ifs - .iter() - .all(|x| x.custom == Some(self.ctx.get_primitive(primitives::BOOL_TYPE))) { - Ok(Some(TypeEnum::ParametricType( - primitives::LIST_TYPE, - vec![elt.custom.clone().ok_or_else(|| "elements should have value".to_string())?]).into())) - } else { - Err("test must be bool".into()) - } - } - - fn fold_comprehension_first(&mut self, node: ast::Comprehension>) -> Result>, String> { - Ok(ast::Comprehension { - target: node.target, - iter: Box::new(self.fold_expr(*node.iter)?), - ifs: node.ifs, - is_async: node.is_async - }) - } - - fn fold_comprehension_second(&mut self, node: ast::Comprehension>) -> Result>, String> { - Ok(ast::Comprehension { - target: Box::new(self.fold_expr(*node.target)?), - iter: node.iter, - ifs: node - .ifs - .into_iter() - .map(|x| self.fold_expr(x)) - .collect::, _>>()?, - is_async: node.is_async - }) - } - - fn infer_simple_binding(&mut self, name: &ast::Expr>, ty: Type) -> Result<(), String> { - match &name.node { - ast::ExprKind::Name {id, ctx: _} => { - if id == "_" { - Ok(()) - } else if self.ctx.defined(id) { - Err("duplicated naming".into()) - } else { - self.ctx.assign(id.clone(), ty, name.location)?; - Ok(()) - } - } - - ast::ExprKind::Tuple {elts, ctx: _} => { - if let TypeEnum::ParametricType(primitives::TUPLE_TYPE, ls) = ty.as_ref() { - if elts.len() == ls.len() { - for (a, b) in elts.iter().zip(ls.iter()) { - self.infer_simple_binding(a, b.clone())?; - } - Ok(()) - } else { - Err("different length".into()) - } - } else { - Err("not supported".into()) - } - } - _ => Err("not supported".into()) - } - } -} - - -impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { +impl<'a> ast::fold::Fold> for InferenceContext<'a> { type TargetU = Option; type Error = String; fn map_user(&mut self, user: Option) -> Result { Ok(user) } - + fn fold_expr(&mut self, node: ast::Expr>) -> Result, Self::Error> { - assert_eq!(node.custom, None); // NOTE: should pass + assert_eq!(node.custom, None); let mut expr = node; - - if let ast::Expr {location, custom, node: ast::ExprKind::ListComp {elt, generators } } = expr { - // is list comprehension, only fold generators which does not include unknown identifiers introduced by list comprehension - if generators.len() != 1 { - return Err("only 1 generator statement is supported".into()) - } - let generators_first_folded = generators - .into_iter() - .map(|x| self.fold_comprehension_first(x)).collect::, _>>()?; - let gen = &generators_first_folded[0]; - let iter_type = gen.iter.custom.as_ref().ok_or("no value".to_string())?.as_ref(); - - if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = iter_type { - self.ctx.stack.level += 1; // FIXME: how to use with_scope?? - - self.infer_simple_binding(&gen.target, ls[0].clone())?; - expr = ast::Expr { - location, - custom, - node: ast::ExprKind::ListComp { - elt: Box::new(self.fold_expr(*elt)?), - generators: generators_first_folded - .into_iter() - .map(|x| self.fold_comprehension_second(x)) - .collect::, _>>()? - } - }; - - self.ctx.stack.level -= 1; - while !self.ctx.stack.sym_def.is_empty() { - let (_, level) = self.ctx.stack.sym_def.last().unwrap(); - if *level > self.ctx.stack.level { - let (name, _) = self.ctx.stack.sym_def.pop().unwrap(); - let (t, b, l) = self.ctx.sym_table.get_mut(&name).unwrap(); - // set it to be unreadable - *b = false; - } else { - break; - } - } - - } else { - return Err("iteration is supported for list only".into()); - } - } else { - // if not listcomp which requires special handling, skip current level, make sure child nodes have their type - expr = ast::fold::fold_expr(self, expr)?; - } + match &expr.node { + ast::ExprKind::ListComp { .. } => expr = self.prefold_list_comprehension(expr)?, + _ => expr = rustpython_parser::ast::fold::fold_expr(self, expr)? + }; match &expr.node { ast::ExprKind::Constant {value, kind: _} => Ok(ast::Expr { location: expr.location, - custom: self.infer_constant_val(value)?, + custom: self.infer_constant(value)?, node: expr.node }), ast::ExprKind::Name {id, ctx: _} => Ok(ast::Expr { location: expr.location, - custom: Some(self.ctx.resolve(id)?), + custom: Some(self.resolve(id)?), node: expr.node }), - ast::ExprKind::List {elts, ctx: _} => { + ast::ExprKind::List {elts, ctx: _} => Ok(ast::Expr { location: expr.location, - custom: self.infer_list_val(elts)?, + custom: self.infer_list(elts)?, node: expr.node - }) - } + }), ast::ExprKind::Tuple {elts, ctx: _} => Ok(ast::Expr { location: expr.location, - custom: self.infer_tuple_val(elts)?, + custom: self.infer_tuple(elts)?, node: expr.node }), @@ -435,7 +72,7 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { Ok(ast::Expr { location: expr.location, custom: inference_core::resolve_call( - &self.ctx, + &self, Some(left.custom.clone().ok_or_else(|| "no value".to_string())?), magic_methods::binop_name(op), &[right.custom.clone().ok_or_else(|| "no value".to_string())?])?, @@ -463,14 +100,6 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { node: expr.node }), - /* // REVIEW: add a new primitive type for slice and do type check of bounds here? - ast::ExprKind::Slice {lower, upper, step } => - Ok(ast::Expr { - location: expr.location, - custom: self.infer_slice(lower, upper, step)?, - node: expr.node - }), */ - ast::ExprKind::Subscript {value, slice, ctx: _} => Ok(ast::Expr { location: expr.location, @@ -485,20 +114,369 @@ impl<'a> ast::fold::Fold> for ExpressionTypeInferencer<'a> { node: expr.node }), - ast::ExprKind::ListComp {elt, generators} => { - + ast::ExprKind::ListComp {elt, generators} => Ok(ast::Expr { location: expr.location, custom: self.infer_list_comprehesion(elt, generators)?, node: expr.node - }) + }), + + // not supported + _ => Err("not supported yet".into()) + } + } +} + +impl<'a> InferenceContext<'a> { + fn infer_constant(&self, constant: &ast::Constant) -> Result, String> { + match constant { + ast::Constant::Bool(_) => + Ok(Some(self.get_primitive(primitives::BOOL_TYPE))), + + ast::Constant::Int(val) => { + let int32: Result = val.try_into(); + let int64: Result = val.try_into(); + + if int32.is_ok() { + Ok(Some(self.get_primitive(primitives::INT32_TYPE))) + } else if int64.is_ok() { + Ok(Some(self.get_primitive(primitives::INT64_TYPE))) + } else { + Err("Integer out of bound".into()) + } + }, + + ast::Constant::Float(_) => + Ok(Some(self.get_primitive(primitives::FLOAT_TYPE))), + + ast::Constant::Tuple(vals) => { + let result = vals + .into_iter() + .map(|x| self.infer_constant(x)) + .collect::>(); + + if result.iter().all(|x| x.is_ok()) { + Ok(Some(TypeEnum::ParametricType( + primitives::TUPLE_TYPE, + result + .into_iter() + .map(|x| x.unwrap().unwrap()) + .collect::>(), + ).into())) + } else { + Err("Some elements in tuple cannot be typed".into()) + } } - _ => { // not supported - Err("not supported yet".into()) + _ => Err("not supported".into()) + } + } + + fn infer_list(&self, elts: &Vec>>) -> Result, String> { + if elts.is_empty() { + Ok(Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![TypeEnum::BotType.into()]).into())) + } else { + let types = elts + .iter() + .map(|x| &x.custom) + .collect::>(); + + if types.iter().all(|x| x.is_some()) { + let head = types.iter().next().unwrap(); // here unwrap alone should be fine after the previous check + if types.iter().all(|x| x.eq(head)) { + Ok(Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![(*head).clone().unwrap()]).into())) + } else { + Err("inhomogeneous list is not allowed".into()) + } + } else { + Err("list elements must have some type".into()) } } } + + fn infer_tuple(&self, elts: &Vec>>) -> Result, String> { + let types = elts + .iter() + .map(|x| (x.custom).clone()) + .collect::>(); + + if types.iter().all(|x| x.is_some()) { + Ok(Some(TypeEnum::ParametricType( + primitives::TUPLE_TYPE, + types.into_iter().map(|x| x.unwrap()).collect()).into())) // unwrap alone should be fine after the previous check + } else { + Err("tuple elements must have some type".into()) + } + } + + fn infer_arrtibute(&self, value: &Box>>, attr: &str) -> Result, String> { + let ty = value.custom.clone().ok_or_else(|| "no value".to_string())?; + if let TypeEnum::TypeVariable(id) = ty.as_ref() { + let v = self.get_variable_def(*id); + if v.bound.is_empty() { + return Err("no fields on unbounded type variable".into()); + } + let ty = v.bound[0].get_base(&self).and_then(|v| v.fields.get(attr)); + if ty.is_none() { + return Err("unknown field".into()); + } + for x in v.bound[1..].iter() { + let ty1 = x.get_base(&self).and_then(|v| v.fields.get(attr)); + if ty1 != ty { + return Err("unknown field (type mismatch between variants)".into()); + } + } + return Ok(Some(ty.unwrap().clone())); + } + + match ty.get_base(&self) { + Some(b) => match b.fields.get(attr) { + Some(t) => Ok(Some(t.clone())), + None => Err("no such field".into()), + }, + None => Err("this object has no fields".into()), + } + } + + fn infer_bool_ops(&self, values: &Vec>>) -> Result, String> { + assert_eq!(values.len(), 2); + let left = values[0].custom.clone().ok_or_else(|| "no value".to_string())?; + let right = values[1].custom.clone().ok_or_else(|| "no value".to_string())?; + let b = self.get_primitive(primitives::BOOL_TYPE); + if left == b && right == b { + Ok(Some(b)) + } else { + Err("bool operands must be bool".to_string()) + } + } + + fn _infer_bin_ops(&self, _left: &Box>>, _op: &ast::Operator, _right: &Box>>) -> Result, String> { + Err("no need this function".into()) + } + + fn infer_unary_ops(&self, op: &ast::Unaryop, operand: &Box>>) -> Result, String> { + if let ast::Unaryop::Not = op { + if (**operand).custom == Some(self.get_primitive(primitives::BOOL_TYPE)) { + Ok(Some(self.get_primitive(primitives::BOOL_TYPE))) + } else { + Err("logical not must be applied to bool".into()) + } + } else { + inference_core::resolve_call(&self, (**operand).custom.clone(), magic_methods::unaryop_name(op), &[]) + } + } + + fn infer_compare(&self, left: &Box>>, ops: &Vec, comparators: &Vec>>) -> Result, String> { + assert!(comparators.len() > 0); + if left.custom.is_none() || (!comparators.iter().all(|x| x.custom.is_some())) { + Err("comparison operands must have type".into()) + } else { + let bool_type = Some(self.get_primitive(primitives::BOOL_TYPE)); + let ty_first = inference_core::resolve_call( + &self, + Some(left.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()), + magic_methods::comparison_name(&ops[0]).ok_or_else(|| "unsupported comparison".to_string())?, + &[comparators[0].custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?])?; + if ty_first != bool_type { + return Err("comparison result must be boolean".into()); + } + + for ((a, b), op) + in comparators[..(comparators.len() - 1)] + .iter() + .zip(comparators[1..].iter()) + .zip(ops[1..].iter()) { + let ty = inference_core::resolve_call( + &self, + Some(a.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()), + magic_methods::comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?, + &[b.custom.clone().ok_or_else(|| "comparator must be able to be typed".to_string())?.clone()])?; + if ty != bool_type { + return Err("comparison result must be boolean".into()); + } + } + Ok(bool_type) + } + } + + fn infer_call(&self, func: &Box>>, args: &Vec>>, _keywords: &Vec>>) -> Result, String> { + if args.iter().all(|x| x.custom.is_some()) { + match &func.node { + ast::ExprKind::Name {id, ctx: _} + => inference_core::resolve_call( + &self, + None, + id, + &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), + + ast::ExprKind::Attribute {value, attr, ctx: _} + => inference_core::resolve_call( + &self, + Some(value.custom.clone().ok_or_else(|| "no value".to_string())?), + &attr, + &args.iter().map(|x| x.custom.clone().unwrap()).collect::>()), + + _ => Err("not supported".into()) + } + } else { + Err("function params must have type".into()) + } + } + + fn infer_subscript(&self, value: &Box>>, slice: &Box>>) -> Result, String> { + let t = if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = value.custom.as_ref().ok_or_else(|| "no value".to_string())?.as_ref() { + ls[0].clone() + } else { + return Err("subscript is not supported for types other than list".into()); + }; + + if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { + let int32_type = self.get_primitive(primitives::INT32_TYPE); + let l = lower.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; + let u = upper.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; + let s = step.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; + + if l == &int32_type && u == &int32_type && s == &int32_type { + Ok(value.custom.clone()) + } else { + Err("slice must be int32 type".into()) + } + } else if slice.custom == Some(self.get_primitive(primitives::INT32_TYPE)) { + Ok(Some(t)) + } else { + Err("slice or index must be int32 type".into()) + } + } + + fn infer_if_expr(&self, test: &Box>>, body: &Box>>, orelse: &Box>>) -> Result, String> { + if test.custom != Some(self.get_primitive(primitives::BOOL_TYPE)) { + Err("test should be bool".into()) + } else { + if body.custom == orelse.custom { + Ok(body.custom.clone()) + } else { + Err("divergent type at if expression".into()) + } + } + } + + fn infer_list_comprehesion(&self, elt: &Box>>, generators: &Vec>>) -> Result, String> { + if generators[0] + .ifs + .iter() + .all(|x| x.custom == Some(self.get_primitive(primitives::BOOL_TYPE))) { + Ok(Some(TypeEnum::ParametricType( + primitives::LIST_TYPE, + vec![elt.custom.clone().ok_or_else(|| "elements should have value".to_string())?]).into())) + } else { + Err("test must be bool".into()) + } + } + + fn prefold_list_comprehension(&mut self, expr: ast::Expr>) -> Result>, String> { + if let ast::Expr { + location, + custom, + node: ast::ExprKind::ListComp { + elt, + generators}} = expr { + // if is list comprehension, need special pre-fold + if generators.len() != 1 { + return Err("only 1 generator statement is supported".into()); + } + if generators[0].is_async { + return Err("async is not supported".into()); + } + + // fold iter first since it does not contain new identifiers + let generators_first_folded = generators + .into_iter() + .map(|x| -> Result>, String> {Ok(ast::Comprehension { + target: x.target, + iter: Box::new(self.fold_expr(*x.iter)?), // fold here + ifs: x.ifs, + is_async: x.is_async + })}) + .collect::, _>>()?; + + if let TypeEnum::ParametricType( + primitives::LIST_TYPE, + ls) = generators_first_folded[0] + .iter + .custom + .as_ref() + .ok_or_else(|| "no value".to_string())? + .as_ref() + .clone() { + self.with_scope(|ctx| -> Result>, String> { + ctx.infer_simple_binding( + &generators_first_folded[0].target, + ls[0].clone())?; + Ok(ast::Expr { + location, + custom, + node: ast::ExprKind::ListComp { // now fold things with new name + elt: Box::new(ctx.fold_expr(*elt)?), + generators: generators_first_folded + .into_iter() + .map(|x| -> Result>, String> {Ok(ast::Comprehension { + target: Box::new(ctx.fold_expr(*x.target)?), + iter: x.iter, + ifs: x + .ifs + .into_iter() + .map(|x| ctx.fold_expr(x)) + .collect::, _>>()?, + is_async: x.is_async + })}) + .collect::, _>>()? + } + }) + }).1 + } else { + Err("iteration is supported for list only".into()) + } + } else { + panic!("this function is for list comprehensions only!"); + } + } + + fn infer_simple_binding(&mut self, name: &ast::Expr>, ty: Type) -> Result<(), String> { + match &name.node { + ast::ExprKind::Name {id, ctx: _} => { + if id == "_" { + Ok(()) + } else if self.defined(id) { + Err("duplicated naming".into()) + } else { + self.assign(id.clone(), ty, name.location)?; + Ok(()) + } + } + + ast::ExprKind::Tuple {elts, ctx: _} => { + if let TypeEnum::ParametricType(primitives::TUPLE_TYPE, ls) = ty.as_ref() { + if elts.len() == ls.len() { + for (a, b) in elts.iter().zip(ls.iter()) { + self.infer_simple_binding(a, b.clone())?; + } + Ok(()) + } else { + Err("different length".into()) + } + } else { + Err("not supported".into()) + } + } + _ => Err("not supported".into()) + } + } } pub mod test { @@ -507,7 +485,7 @@ pub mod test { use rustpython_parser::ast::{self, Expr, fold::Fold}; use super::*; - pub fn new_ctx<'a>() -> ExpressionTypeInferencer<'a>{ + pub fn new_ctx<'a>() -> InferenceContext<'a>{ struct S; impl SymbolResolver for S { @@ -524,9 +502,7 @@ pub mod test { } } - ExpressionTypeInferencer { - ctx: InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3)), - } + InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3)) } @@ -552,7 +528,7 @@ pub mod test { new_ast, Expr { location: location, - custom: Some(inferencer.ctx.get_primitive(primitives::INT64_TYPE)), + custom: Some(inferencer.get_primitive(primitives::INT64_TYPE)), node: ast::ExprKind::Constant { value: ast::Constant::Int(num.into()), kind: None, @@ -598,13 +574,13 @@ pub mod test { new_ast, Expr { location, - custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.ctx.get_primitive(primitives::INT32_TYPE).into()]).into()), + custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.get_primitive(primitives::INT32_TYPE).into()]).into()), node: ast::ExprKind::List { ctx: ast::ExprContext::Load, elts: vec![ Expr { location, - custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)), + custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)), node: ast::ExprKind::Constant { value: ast::Constant::Int(1.into()), kind: None, @@ -613,7 +589,8 @@ pub mod test { Expr { location, - custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)), + custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)), + // custom: None, node: ast::ExprKind::Constant { value: ast::Constant::Int(2.into()), kind: None, @@ -625,4 +602,4 @@ pub mod test { ); } -} +} \ No newline at end of file