use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; use itertools::izip; use rustpython_parser::ast::{ self, fold::{self, Fold}, Arguments, Comprehension, ExprKind, Located, Location, StrRef, }; #[cfg(test)] mod test; #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)] pub struct CodeLocation { row: usize, col: usize, } impl From for CodeLocation { fn from(loc: Location) -> CodeLocation { CodeLocation { row: loc.row(), col: loc.column() } } } #[derive(Clone, Copy)] pub struct PrimitiveStore { pub int32: Type, pub int64: Type, pub float: Type, pub bool: Type, pub none: Type, } pub struct FunctionData { pub resolver: Arc>, pub return_type: Option, pub bound_variables: Vec, } pub struct Inferencer<'a> { pub top_level: &'a TopLevelContext, pub defined_identifiers: HashSet, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, pub virtual_checks: &'a mut Vec<(Type, Type)>, pub variable_mapping: HashMap, pub calls: &'a mut HashMap, } struct NaiveFolder(); impl fold::Fold<()> for NaiveFolder { type TargetU = Option; type Error = String; fn map_user(&mut self, _: ()) -> Result { Ok(None) } } impl<'a> fold::Fold<()> for Inferencer<'a> { type TargetU = Option; type Error = String; fn map_user(&mut self, _: ()) -> Result { Ok(None) } fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result, Self::Error> { let stmt = match node.node { // we don't want fold over type annotation ast::StmtKind::AnnAssign { target, annotation, value, simple } => { self.infer_pattern(&target)?; let target = Box::new(self.fold_expr(*target)?); let value = if let Some(v) = value { let ty = Box::new(self.fold_expr(*v)?); self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?; Some(ty) } else { return Err(format!( "declaration without definition is not yet supported, at {}", node.location )); }; let top_level_defs = self.top_level.definitions.read(); let annotation_type = self.function_data.resolver.parse_type_annotation( top_level_defs.as_slice(), self.unifier, &self.primitives, annotation.as_ref(), )?; self.unify(annotation_type, target.custom.unwrap(), &node.location)?; let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); Located { location: node.location, custom: None, node: ast::StmtKind::AnnAssign { target, annotation, value, simple }, } } ast::StmtKind::For { ref target, .. } => { self.infer_pattern(target)?; fold::fold_stmt(self, node)? } ast::StmtKind::Assign { ref targets, .. } => { for target in targets { self.infer_pattern(target)?; } fold::fold_stmt(self, node)? } _ => fold::fold_stmt(self, node)?, }; match &stmt.node { ast::StmtKind::For { target, iter, .. } => { let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); self.unify(list, iter.custom.unwrap(), &iter.location)?; } ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; } ast::StmtKind::Assign { targets, value, .. } => { for target in targets.iter() { self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?; } } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} ast::StmtKind::Break | ast::StmtKind::Continue | ast::StmtKind::Pass => {} ast::StmtKind::Return { value } => match (value, self.function_data.return_type) { (Some(v), Some(v1)) => { self.unify(v.custom.unwrap(), v1, &v.location)?; } (Some(_), None) => { return Err("Unexpected return value".to_string()); } (None, Some(_)) => { return Err("Expected return value".to_string()); } (None, None) => {} }, _ => return Err("Unsupported statement type".to_string()), }; Ok(stmt) } fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { let expr = match node.node { ast::ExprKind::Call { func, args, keywords } => { return self.fold_call(node.location, *func, args, keywords); } ast::ExprKind::Lambda { args, body } => { return self.fold_lambda(node.location, *args, *body); } ast::ExprKind::ListComp { elt, generators } => { return self.fold_listcomp(node.location, *elt, generators); } _ => fold::fold_expr(self, node)?, }; let custom = match &expr.node { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { if self.function_data.resolver.get_identifier_def(*id).is_some() { self.defined_identifiers.insert(id.clone()); } else { return Err(format!( "unknown identifier {} (use before def?) at {}", id, expr.location )); } } Some(self.infer_identifier(*id)?) } ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ast::ExprKind::Attribute { value, attr, ctx: _ } => { Some(self.infer_attribute(value, *attr)?) } ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BinOp { left, op, right } => { Some(self.infer_bin_ops(expr.location, left, op, right)?) } ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), ast::ExprKind::Compare { left, ops, comparators } => { Some(self.infer_compare(left, ops, comparators)?) } ast::ExprKind::Subscript { value, slice, .. } => { Some(self.infer_subscript(value.as_ref(), slice.as_ref())?) } ast::ExprKind::IfExp { test, body, orelse } => { Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) } ast::ExprKind::ListComp { .. } | ast::ExprKind::Lambda { .. } | ast::ExprKind::Call { .. } => expr.custom, // already computed ast::ExprKind::Slice { .. } => None, // we don't need it for slice _ => return Err("not supported yet".into()), }; Ok(ast::Expr { custom, location: expr.location, node: expr.node }) } } type InferenceResult = Result; impl<'a> Inferencer<'a> { /// Constrain a <: b /// Currently implemented as unification fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) } fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) } fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> { match &pattern.node { ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { self.defined_identifiers.insert(id.clone()); } Ok(()) } ExprKind::Tuple { elts, .. } => { for elt in elts.iter() { self.infer_pattern(elt)?; } Ok(()) } _ => Ok(()), } } fn build_method_call( &mut self, location: Location, method: StrRef, obj: Type, params: Vec, ret: Option, ) -> InferenceResult { if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { if class_params.borrow().is_empty() { if let Some(ty) = fields.borrow().get(&method) { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) { let sign = sign.borrow(); if sign.vars.is_empty() { if let Some(ret) = ret { self.unifier.unify(sign.ret, ret).unwrap(); } return Ok(sign.ret); } } } } } let ret = ret.unwrap_or_else(|| self.unifier.get_fresh_var().0); let call = self.unifier.add_call(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None), }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let fields = once((method, call)).collect(); let record = self.unifier.add_record(fields); self.constrain(obj, record, &location)?; Ok(ret) } fn fold_lambda( &mut self, location: Location, args: Arguments, body: ast::Expr<()>, ) -> Result>, String> { if !args.posonlyargs.is_empty() || args.vararg.is_some() || !args.kwonlyargs.is_empty() || args.kwarg.is_some() || !args.defaults.is_empty() { // actually I'm not sure whether programs violating this is a valid python program. return Err( "We only support positional or keyword arguments without defaults for lambdas." .to_string(), ); } let mut defined_identifiers = self.defined_identifiers.clone(); for arg in args.args.iter() { let name = &arg.node.arg; if !defined_identifiers.contains(name) { defined_identifiers.insert(name.clone()); } } let fn_args: Vec<_> = args .args .iter() .map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0)) .collect(); let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().cloned()); let ret = self.unifier.get_fresh_var().0; let mut new_context = Inferencer { function_data: self.function_data, unifier: self.unifier, primitives: self.primitives, virtual_checks: self.virtual_checks, calls: self.calls, top_level: self.top_level, defined_identifiers, variable_mapping, }; let fun = FunSignature { args: fn_args .iter() .map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None }) .collect(), ret, vars: Default::default(), }; let body = new_context.fold_expr(body)?; new_context.unify(fun.ret, body.custom.unwrap(), &location)?; let mut args = new_context.fold_arguments(args)?; for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) { assert_eq!(&arg.node.arg, name); arg.custom = Some(*ty); } Ok(Located { location, node: ExprKind::Lambda { args: args.into(), body: body.into() }, custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))), }) } fn fold_listcomp( &mut self, location: Location, elt: ast::Expr<()>, mut generators: Vec, ) -> Result>, String> { if generators.len() != 1 { return Err( "Only 1 generator statement for list comprehension is supported.".to_string() ); } let variable_mapping = self.variable_mapping.clone(); let defined_identifiers = self.defined_identifiers.clone(); let mut new_context = Inferencer { function_data: self.function_data, unifier: self.unifier, virtual_checks: self.virtual_checks, top_level: self.top_level, variable_mapping, primitives: self.primitives, calls: self.calls, defined_identifiers, }; let generator = generators.pop().unwrap(); if generator.is_async { return Err("Async iterator not supported.".to_string()); } new_context.infer_pattern(&generator.target)?; let elt = new_context.fold_expr(elt)?; let target = new_context.fold_expr(*generator.target)?; let iter = new_context.fold_expr(*generator.iter)?; let ifs: Vec<_> = generator .ifs .into_iter() .map(|v| new_context.fold_expr(v)) .collect::>()?; // iter should be a list of targets... // actually it should be an iterator of targets, but we don't have iter type for now let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); new_context.unify(iter.custom.unwrap(), list, &iter.location)?; // if conditions should be bool for v in ifs.iter() { new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?; } Ok(Located { location, custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })), node: ExprKind::ListComp { elt: Box::new(elt), generators: vec![ast::Comprehension { target: Box::new(target), iter: Box::new(iter), ifs, is_async: false, }], }, }) } fn fold_call( &mut self, location: Location, func: ast::Expr<()>, mut args: Vec>, keywords: Vec>, ) -> Result>, String> { let func = if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } = func { // handle special functions that cannot be typed in the usual way... if id == "virtual".into() { if args.is_empty() || args.len() > 2 || !keywords.is_empty() { return Err( "`virtual` can only accept 1/2 positional arguments.".to_string() ); } let arg0 = self.fold_expr(args.remove(0))?; let ty = if let Some(arg) = args.pop() { let top_level_defs = self.top_level.definitions.read(); self.function_data.resolver.parse_type_annotation( top_level_defs.as_slice(), self.unifier, self.primitives, &arg, )? } else { self.unifier.get_fresh_var().0 }; self.virtual_checks.push((arg0.custom.unwrap(), ty)); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); return Ok(Located { location, custom, node: ExprKind::Call { func: Box::new(Located { custom: None, location: func.location, node: ExprKind::Name { id, ctx }, }), args: vec![arg0], keywords: vec![], }, }); } // int64 is special because its argument can be a constant larger than int32 if id == "int64".into() && args.len() == 1 { if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = &args[0].node { let int64: Result = val.try_into(); let custom; if int64.is_ok() { custom = Some(self.primitives.int64); } else { return Err("Integer out of bound".into()); } return Ok(Located { location: args[0].location, custom, node: ExprKind::Constant { value: ast::Constant::Int(val.clone()), kind: kind.clone(), }, }); } } Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } } else { func }; let func = Box::new(self.fold_expr(func)?); let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) .collect::, _>>()?; if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { let sign = sign.borrow(); if sign.vars.is_empty() { return Ok(Located { location, custom: Some(sign.ret), node: ExprKind::Call { func, args, keywords }, }); } } let ret = self.unifier.get_fresh_var().0; let call = self.unifier.add_call(Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords .iter() .map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap())) .collect(), fun: RefCell::new(None), ret, }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); self.unify(func.custom.unwrap(), call, &func.location)?; Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) } fn infer_identifier(&mut self, id: StrRef) -> InferenceResult { if let Some(ty) = self.variable_mapping.get(&id) { Ok(*ty) } else { let variable_mapping = &mut self.variable_mapping; let unifier = &mut self.unifier; Ok(self .function_data .resolver .get_symbol_type(unifier, self.primitives, id) .unwrap_or_else(|| { let ty = unifier.get_fresh_var().0; variable_mapping.insert(id, ty); ty })) } } fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult { match constant { ast::Constant::Bool(_) => Ok(self.primitives.bool), ast::Constant::Int(val) => { let int32: Result = val.try_into(); // int64 would be handled separately in functions if int32.is_ok() { Ok(self.primitives.int32) } else { Err("Integer out of bound".into()) } } ast::Constant::Float(_) => Ok(self.primitives.float), ast::Constant::Tuple(vals) => { let ty: Result, _> = vals.iter().map(|x| self.infer_constant(x)).collect(); Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? })) } _ => Err("not supported".into()), } } fn infer_list(&mut self, elts: &[ast::Expr>]) -> InferenceResult { let (ty, _) = self.unifier.get_fresh_var(); for t in elts.iter() { self.unify(ty, t.custom.unwrap(), &t.location)?; } Ok(self.unifier.add_ty(TypeEnum::TList { ty })) } fn infer_tuple(&mut self, elts: &[ast::Expr>]) -> InferenceResult { let ty = elts.iter().map(|x| x.custom.unwrap()).collect(); Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) } fn infer_attribute( &mut self, value: &ast::Expr>, attr: StrRef, ) -> InferenceResult { let (attr_ty, _) = self.unifier.get_fresh_var(); let fields = once((attr, attr_ty)).collect(); let record = self.unifier.add_record(fields); self.constrain(value.custom.unwrap(), record, &value.location)?; Ok(attr_ty) } fn infer_bool_ops(&mut self, values: &[ast::Expr>]) -> InferenceResult { let b = self.primitives.bool; for v in values { self.constrain(v.custom.unwrap(), b, &v.location)?; } Ok(b) } fn infer_bin_ops( &mut self, location: Location, left: &ast::Expr>, op: &ast::Operator, right: &ast::Expr>, ) -> InferenceResult { let method = binop_name(op).into(); self.build_method_call( location, method, left.custom.unwrap(), vec![right.custom.unwrap()], None, ) } fn infer_unary_ops( &mut self, op: &ast::Unaryop, operand: &ast::Expr>, ) -> InferenceResult { let method = unaryop_name(op).into(); self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None) } fn infer_compare( &mut self, left: &ast::Expr>, ops: &[ast::Cmpop], comparators: &[ast::Expr>], ) -> InferenceResult { let boolean = self.primitives.bool; for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { let method = comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.into(); self.build_method_call( a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], Some(boolean), )?; } Ok(boolean) } fn infer_subscript( &mut self, value: &ast::Expr>, slice: &ast::Expr>, ) -> InferenceResult { let ty = self.unifier.get_fresh_var().0; match &slice.node { ast::ExprKind::Slice { lower, upper, step } => { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; } let list = self.unifier.add_ty(TypeEnum::TList { ty }); self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(list) } ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => { // the index is a constant, so value can be a sequence. let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?; let map = once((ind, ty)).collect(); let seq = self.unifier.add_sequence(map); self.constrain(value.custom.unwrap(), seq, &value.location)?; Ok(ty) } _ => { // the index is not a constant, so value can only be a list self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; let list = self.unifier.add_ty(TypeEnum::TList { ty }); self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(ty) } } } fn infer_if_expr( &mut self, test: &ast::Expr>, body: &ast::Expr>, orelse: &ast::Expr>, ) -> InferenceResult { self.constrain(test.custom.unwrap(), self.primitives.bool, &test.location)?; let ty = self.unifier.get_fresh_var().0; self.constrain(body.custom.unwrap(), ty, &body.location)?; self.constrain(orelse.custom.unwrap(), ty, &orelse.location)?; Ok(ty) } }