diff --git a/nac3core/src/typecheck/symbol_resolver.rs b/nac3core/src/typecheck/symbol_resolver.rs index accb1aae..96003410 100644 --- a/nac3core/src/typecheck/symbol_resolver.rs +++ b/nac3core/src/typecheck/symbol_resolver.rs @@ -1,5 +1,6 @@ use super::typedef::Type; use super::location::Location; +use rustpython_parser::ast::Expr; pub enum SymbolType { TypeName(Type), @@ -16,7 +17,8 @@ pub enum SymbolValue<'a> { } pub trait SymbolResolver { - fn get_symbol_type(&mut self, str: &str) -> Option; + fn get_symbol_type(&mut self, str: &str) -> Option; + fn parse_type_name(&mut self, expr: &Expr<()>) -> Option; fn get_symbol_value(&mut self, str: &str) -> Option; fn get_symbol_location(&mut self, str: &str) -> Option; // handle function call etc. diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs index d26dae6f..18cb87e7 100644 --- a/nac3core/src/typecheck/type_inferencer.rs +++ b/nac3core/src/typecheck/type_inferencer.rs @@ -5,7 +5,7 @@ use std::iter::once; use std::rc::Rc; use super::magic_methods::*; -use super::symbol_resolver::{SymbolResolver, SymbolType}; +use super::symbol_resolver::SymbolResolver; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use itertools::izip; use rustpython_parser::ast::{ @@ -44,7 +44,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { func, args, keywords, - } => unimplemented!(), + } => self.fold_call(node.location, *func, args, keywords)?, ast::ExprKind::Lambda { args, body } => { self.fold_lambda(node.location, *args, *body)? } @@ -71,14 +71,15 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ops, comparators, } => Some(self.infer_compare(left, ops, comparators)?), - ast::ExprKind::Call { .. } => expr.custom, ast::ExprKind::Subscript { value, slice, .. } => { Some(self.infer_subscript(value.as_ref(), slice.as_ref())?) } ast::ExprKind::IfExp { test, body, orelse } => { Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) } - ast::ExprKind::ListComp { .. } | ast::ExprKind::Lambda { .. } => expr.custom, // already computed + ast::ExprKind::ListComp { .. } + | ast::ExprKind::Lambda { .. } + | ast::ExprKind::Call { .. } => expr.custom, // already computed ast::ExprKind::Slice { .. } => None, // we don't need it for slice _ => return Err("not supported yet".into()), }; @@ -243,21 +244,127 @@ impl<'a> Inferencer<'a> { }) } + fn fold_call( + &mut self, + location: Location, + func: ast::Expr<()>, + mut args: Vec>, + keywords: Vec>, + ) -> Result>, String> { + let func = if let Located { + location: func_location, + custom, + node: ExprKind::Name { id, ctx }, + } = func + { + // handle special functions that cannot be typed in the usual way... + if id == "virtual" { + if args.is_empty() || args.len() > 2 || !keywords.is_empty() { + return Err("`virtual` can only accept 1/2 positional arguments.".to_string()); + } + let arg0 = self.fold_expr(args.remove(0))?; + let ty = if let Some(arg) = args.pop() { + self.resolver + .parse_type_name(&arg) + .ok_or_else(|| "error parsing type".to_string())? + } else { + self.unifier.get_fresh_var().0 + }; + let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); + return Ok(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id, ctx }, + }), + args: vec![arg0], + keywords: vec![], + }, + }); + } + // int64 is special because its argument can be a constant larger than int32 + if id == "int64" && args.len() == 1 { + if let ExprKind::Constant { + value: ast::Constant::Int(val), + .. + } = &args[0].node + { + let int64: Result = val.try_into(); + let custom; + if int64.is_ok() { + custom = Some(self.primitives.int64); + } else { + return Err("Integer out of bound".into()); + } + return Ok(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id, ctx }, + }), + args: vec![self.fold_expr(args.pop().unwrap())?], + keywords: vec![], + }, + }); + } + } + Located { + location: func_location, + custom, + node: ExprKind::Name { id, ctx }, + } + } else { + func + }; + let func = Box::new(self.fold_expr(func)?); + let args = args + .into_iter() + .map(|v| self.fold_expr(v)) + .collect::, _>>()?; + let keywords = keywords + .into_iter() + .map(|v| fold::fold_keyword(self, v)) + .collect::, _>>()?; + let ret = self.unifier.get_fresh_var().0; + let call = Rc::new(Call { + posargs: args.iter().map(|v| v.custom.unwrap()).collect(), + kwargs: keywords + .iter() + .map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap())) + .collect(), + fun: RefCell::new(None), + ret, + }); + self.calls.push(call.clone()); + let call = self.unifier.add_ty(TypeEnum::TCall { calls: vec![call] }); + self.unifier.unify(func.custom.unwrap(), call)?; + + Ok(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func, + args, + keywords, + }, + }) + } + fn infer_identifier(&mut self, id: &str) -> InferenceResult { if let Some(ty) = self.variable_mapping.get(id) { Ok(*ty) } else { - match self.resolver.get_symbol_type(id) { - Some(SymbolType::TypeName(_)) => { - Err("Expected expression instead of type".to_string()) - } - Some(SymbolType::Identifier(ty)) => Ok(ty), - None => { - let ty = self.unifier.get_fresh_var().0; - self.variable_mapping.insert(id.to_string(), ty); - Ok(ty) - } - } + Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| { + let ty = self.unifier.get_fresh_var().0; + self.variable_mapping.insert(id.to_string(), ty); + ty + })) } } @@ -268,7 +375,7 @@ impl<'a> Inferencer<'a> { let int32: Result = val.try_into(); // int64 would be handled separately in functions if int32.is_ok() { - Ok(self.primitives.int64) + Ok(self.primitives.int32) } else { Err("Integer out of bound".into()) }