diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c436076..7ff3e0b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -10,7 +10,13 @@ use itertools::izip; use nac3parser::ast::{ self, fold::{self, Fold}, - Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, + Arguments, + Comprehension, + ExprContext, + ExprKind, + Located, + Location, + StrRef }; #[cfg(test)] @@ -773,6 +779,120 @@ impl<'a> Inferencer<'a> { }) } + /// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise + /// returns [`None`]. + fn try_fold_special_call( + &mut self, + location: Location, + func: &ast::Expr<()>, + args: &mut Vec>, + keywords: &Vec>, + ) -> Result>>, HashSet> { + let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { + return Ok(None) + }; + + // handle special functions that cannot be typed in the usual way... + if id == &"virtual".into() { + if args.is_empty() || args.len() > 2 || !keywords.is_empty() { + return report_error( + "`virtual` can only accept 1/2 positional arguments", + *func_location, + ) + } + let arg0 = self.fold_expr(args.remove(0))?; + let ty = if let Some(arg) = args.pop() { + let top_level_defs = self.top_level.definitions.read(); + self.function_data.resolver.parse_type_annotation( + top_level_defs.as_slice(), + self.unifier, + self.primitives, + &arg, + )? + } else { + self.unifier.get_dummy_var().0 + }; + self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location)); + let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); + return Ok(Some(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + // int64 is special because its argument can be a constant larger than int32 + if id == &"int64".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.int64); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + if id == &"uint32".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint32); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + if id == &"uint64".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint64); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + + Ok(None) + } + fn fold_call( &mut self, location: Location, @@ -780,111 +900,11 @@ impl<'a> Inferencer<'a> { mut args: Vec>, keywords: Vec>, ) -> Result>, HashSet> { - let func = - if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } = - func - { - // handle special functions that cannot be typed in the usual way... - if id == "virtual".into() { - if args.is_empty() || args.len() > 2 || !keywords.is_empty() { - return report_error( - "`virtual` can only accept 1/2 positional arguments", - func_location, - ); - } - let arg0 = self.fold_expr(args.remove(0))?; - let ty = if let Some(arg) = args.pop() { - let top_level_defs = self.top_level.definitions.read(); - self.function_data.resolver.parse_type_annotation( - top_level_defs.as_slice(), - self.unifier, - self.primitives, - &arg, - )? - } else { - self.unifier.get_dummy_var().0 - }; - self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location)); - let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); - return Ok(Located { - location, - custom, - node: ExprKind::Call { - func: Box::new(Located { - custom: None, - location: func.location, - node: ExprKind::Name { id, ctx }, - }), - args: vec![arg0], - keywords: vec![], - }, - }); - } - // int64 is special because its argument can be a constant larger than int32 - if id == "int64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.int64); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == "uint32".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint32); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == "uint64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint64); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } - } else { - func - }; + let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { + return Ok(spec_call_func) + } else { + func + }; let func = Box::new(self.fold_expr(func)?); let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; let keywords = keywords