diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 29868941..c114f67f 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -35,13 +35,11 @@ impl<'a> Inferencer<'a> { ) -> Result<(), String> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { - let ty = self.unifier.get_ty(*ty); - let ty = ty.as_ref(); - if !ty.is_concrete() { + if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { return Err(format!( "expected concrete type at {} but got {}", expr.location, - ty.get_type_name() + self.unifier.get_ty(*ty).get_type_name() )); } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 754ecd97..f852fab9 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -25,14 +25,18 @@ pub struct PrimitiveStore { pub none: Type, } +pub struct FunctionData { + pub resolver: Box, + pub return_type: Option, + pub bound_variables: Vec, +} + pub struct Inferencer<'a> { - pub resolver: &'a mut Box, + pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, + pub primitives: &'a PrimitiveStore, pub virtual_checks: &'a mut Vec<(Type, Type)>, pub variable_mapping: HashMap, - pub calls: &'a mut Vec>, - pub primitives: &'a PrimitiveStore, - pub return_type: Option, } struct NaiveFolder(); @@ -65,6 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { None }; let annotation_type = self + .function_data .resolver .parse_type_name(annotation.as_ref()) .ok_or_else(|| "cannot parse type name".to_string())?; @@ -93,7 +98,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} ast::StmtKind::Break | ast::StmtKind::Continue => {} - ast::StmtKind::Return { value } => match (value, self.return_type) { + ast::StmtKind::Return { value } => match (value, self.function_data.return_type) { (Some(v), Some(v1)) => { self.unifier.unify(v.custom.unwrap(), v1)?; } @@ -171,7 +176,6 @@ impl<'a> Inferencer<'a> { ) -> InferenceResult { let call = Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) }); - self.calls.push(call.clone()); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let fields = once((method, call)).collect(); let record = self.unifier.add_record(fields); @@ -207,13 +211,11 @@ impl<'a> Inferencer<'a> { variable_mapping.extend(fn_args.iter().cloned()); let ret = self.unifier.get_fresh_var().0; let mut new_context = Inferencer { - resolver: self.resolver, + function_data: self.function_data, unifier: self.unifier, + primitives: self.primitives, virtual_checks: self.virtual_checks, variable_mapping, - calls: self.calls, - primitives: self.primitives, - return_type: self.return_type, }; let fun = FunSignature { args: fn_args @@ -250,13 +252,11 @@ impl<'a> Inferencer<'a> { } let variable_mapping = self.variable_mapping.clone(); let mut new_context = Inferencer { - resolver: self.resolver, + function_data: self.function_data, unifier: self.unifier, virtual_checks: self.virtual_checks, variable_mapping, - calls: self.calls, primitives: self.primitives, - return_type: self.return_type, }; let elt = new_context.fold_expr(elt)?; let generator = generators.pop().unwrap(); @@ -315,7 +315,7 @@ impl<'a> Inferencer<'a> { } let arg0 = self.fold_expr(args.remove(0))?; let ty = if let Some(arg) = args.pop() { - self.resolver + self.function_data.resolver .parse_type_name(&arg) .ok_or_else(|| "error parsing type".to_string())? } else { @@ -379,7 +379,6 @@ impl<'a> Inferencer<'a> { fun: RefCell::new(None), ret, }); - self.calls.push(call.clone()); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); self.unifier.unify(func.custom.unwrap(), call)?; @@ -390,7 +389,7 @@ impl<'a> Inferencer<'a> { if let Some(ty) = self.variable_mapping.get(id) { Ok(*ty) } else { - Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| { + Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| { let ty = self.unifier.get_fresh_var().0; self.variable_mapping.insert(id.to_string(), ty); ty diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index bae64f28..16a5ffff 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -37,8 +37,7 @@ impl SymbolResolver for Resolver { struct TestEnvironment { pub unifier: Unifier, - pub resolver: Box, - pub calls: Vec>, + pub function_data: FunctionData, pub primitives: PrimitiveStore, pub id_to_name: HashMap, pub identifier_mapping: HashMap, @@ -149,24 +148,25 @@ impl TestEnvironment { TestEnvironment { unifier, - resolver, + function_data: FunctionData { + resolver, + bound_variables: Vec::new(), + return_type: None + }, primitives, id_to_name, identifier_mapping, - calls: Vec::new(), virtual_checks: Vec::new(), } } fn get_inferencer(&mut self) -> Inferencer { Inferencer { - resolver: &mut self.resolver, + function_data: &mut self.function_data, unifier: &mut self.unifier, variable_mapping: Default::default(), - calls: &mut self.calls, primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, - return_type: None, } } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 54b8aeb2..c97926da 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -83,10 +83,6 @@ impl TypeEnum { TypeEnum::TFunc { .. } => "TFunc", } } - - pub fn is_concrete(&self) -> bool { - !matches!(self, TypeEnum::TVar { .. }) - } } pub struct Unifier { @@ -143,6 +139,23 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id) } + pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { + use TypeEnum::*; + match &*self.get_ty(a) { + TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), + TCall { .. } => false, + TList { ty } => self.is_concrete(*ty, allowed_typevars), + TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), + TObj { params: vars, .. } => { + vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) + } + // functions are instantiated for each call sites, so the function type can contain + // type variables. + TFunc { .. } => true, + TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), + } + } + pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { if self.unification_table.unioned(a, b) { Ok(()) @@ -204,7 +217,7 @@ impl Unifier { } for v1 in old_range2.iter() { for v2 in range1.iter() { - if let Ok(result) = self.get_intersection(*v1, *v2){ + if let Ok(result) = self.get_intersection(*v1, *v2) { range2.push(result.unwrap_or(*v2)); } } @@ -486,7 +499,7 @@ impl Unifier { Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name())) } - /// Instantiate a function if it hasn't been instntiated. + /// Instantiate a function if it hasn't been instantiated. /// Returns Some(T) where T is the instantiated type. /// Returns None if the function is already instantiated. fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {