diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs index 45f641331e..805626e5ae 100644 --- a/nac3core/src/expression.rs +++ b/nac3core/src/expression.rs @@ -172,8 +172,7 @@ fn parse_bin_ops( let left = parse_expr(ctx, sym_table, left)?.ok_or("no value".to_string())?; let right = parse_expr(ctx, sym_table, right)?.ok_or("no value".to_string())?; let fun = binop_name(op); - let mut assumptions = HashMap::new(); - resolve_call(ctx, Some(left), fun, &[right], &mut assumptions) + resolve_call(ctx, Some(left), fun, &[right]) } fn parse_unary_ops( @@ -183,7 +182,6 @@ fn parse_unary_ops( obj: &Expression, ) -> ParserResult { let ty = parse_expr(ctx, sym_table, obj)?.ok_or("no value".to_string())?; - let mut assumptions = HashMap::new(); if let UnaryOperator::Not = op { if ty.as_ref() == &PrimitiveType(BOOL_TYPE) { Ok(Some(ty)) @@ -191,7 +189,7 @@ fn parse_unary_ops( Err("logical not must be applied to bool".into()) } } else { - resolve_call(ctx, Some(ty), unaryop_name(op), &[], &mut assumptions) + resolve_call(ctx, Some(ty), unaryop_name(op), &[]) } } @@ -211,11 +209,10 @@ fn parse_compare( let boolean = PrimitiveType(BOOL_TYPE); let left = &types[..types.len() - 1]; let right = &types[1..]; - let mut assumptions = HashMap::new(); for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; - let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()], &mut assumptions)?; + let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; if ty.is_none() || ty.unwrap().as_ref() != &boolean { return Err("comparison result must be boolean".into()); } @@ -235,7 +232,6 @@ fn parse_call( if types.is_none() { return Err("function params must have type".into()); } - let mut assumptions = HashMap::new(); let (obj, fun) = match &function.node { ExpressionType::Identifier { name } => (None, name), @@ -245,7 +241,7 @@ fn parse_call( ), _ => return Err("not supported".into()), }; - resolve_call(ctx, obj, fun.as_str(), &types.unwrap(), &mut assumptions) + resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) } fn parse_subscript( diff --git a/nac3core/src/inference.rs b/nac3core/src/inference.rs index cdd28ba64b..1eb351e6c0 100644 --- a/nac3core/src/inference.rs +++ b/nac3core/src/inference.rs @@ -4,15 +4,17 @@ use std::rc::Rc; fn find_subst( ctx: &GlobalContext, - assumptions: &HashMap>, + valuation: &Option<(VariableId, Rc)>, sub: &mut HashMap>, mut a: Rc, mut b: Rc, ) -> Result<(), String> { // TODO: fix error messages later if let TypeVariable(id) = a.as_ref() { - if let Some(c) = assumptions.get(&id) { - a = c.clone(); + if let Some((assumption_id, t)) = valuation { + if assumption_id == id { + a = t.clone(); + } } } @@ -97,7 +99,7 @@ fn find_subst( Err("different parametric types".to_string()) } else { for (x, y) in param_a.iter().zip(param_b.iter()) { - find_subst(ctx, assumptions, sub, x.clone(), y.clone())?; + find_subst(ctx, valuation, sub, x.clone(), y.clone())?; } Ok(()) } @@ -112,14 +114,13 @@ fn find_subst( } } -pub fn resolve_call( +fn resolve_call_rec( ctx: &GlobalContext, + valuation: &Option<(VariableId, Rc)>, obj: Option>, func: &str, args: &[Rc], - assumptions: &mut HashMap>, ) -> Result>, String> { - let obj = obj.as_ref().map(|v| Rc::new(v.subst(assumptions))); let mut subst = obj .as_ref() .map(|v| v.get_subst(ctx)) @@ -137,15 +138,15 @@ pub fn resolve_call( .bound .iter() .map(|ins| { - assumptions.insert(*id, ins.clone()); - resolve_call(ctx, Some(obj.clone()), func, args.clone(), assumptions) + resolve_call_rec( + ctx, + &Some((*id, ins.clone())), + Some(ins.clone()), + func, + args.clone(), + ) }) .collect(); - // `assumption` cannot substitute variable for variable, if assumption contains - // this id before running this function, `obj` would not be a variable, so this - // would not be executed. - // Hence, we lose no information doing this. - assumptions.remove(id); let results = results?; if results.iter().all(|v| v == &results[0]) { return Ok(results[0].clone()); @@ -176,7 +177,7 @@ pub fn resolve_call( return Err("incorrect parameter number".to_string()); } for (a, b) in args.iter().zip(fun.args.iter()) { - find_subst(ctx, assumptions, &mut subst, a.clone(), b.clone())?; + find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?; } let result = fun.result.as_ref().map(|v| v.subst(&subst)); Ok(result.map(|result| { @@ -188,6 +189,15 @@ pub fn resolve_call( })) } +pub fn resolve_call( + ctx: &GlobalContext, + obj: Option>, + func: &str, + args: &[Rc], +) -> Result>, String> { + resolve_call_rec(ctx, &None, obj, func, args) +} + #[cfg(test)] mod tests { use super::*; @@ -196,54 +206,29 @@ mod tests { #[test] fn test_simple_generic() { let mut ctx = basic_ctx(); - let mut assumptions = HashMap::new(); assert_eq!( - resolve_call( - &ctx, - None, - "int32", - &[PrimitiveType(FLOAT_TYPE).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "int32", &[PrimitiveType(FLOAT_TYPE).into()]), Ok(Some(PrimitiveType(INT32_TYPE).into())) ); assert_eq!( - resolve_call( - &ctx, - None, - "int32", - &[PrimitiveType(INT32_TYPE).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "int32", &[PrimitiveType(INT32_TYPE).into()],), Ok(Some(PrimitiveType(INT32_TYPE).into())) ); assert_eq!( - resolve_call( - &ctx, - None, - "float", - &[PrimitiveType(INT32_TYPE).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "float", &[PrimitiveType(INT32_TYPE).into()]), Ok(Some(PrimitiveType(FLOAT_TYPE).into())) ); assert_eq!( - resolve_call( - &ctx, - None, - "float", - &[PrimitiveType(BOOL_TYPE).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "float", &[PrimitiveType(BOOL_TYPE).into()]), Err("different domain".to_string()) ); assert_eq!( - resolve_call(&ctx, None, "float", &[], &mut assumptions), + resolve_call(&ctx, None, "float", &[]), Err("incorrect parameter number".to_string()) ); @@ -256,13 +241,7 @@ mod tests { }); assert_eq!( - resolve_call( - &ctx, - None, - "float", - &[TypeVariable(v1).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "float", &[TypeVariable(v1).into()]), Ok(Some(PrimitiveType(FLOAT_TYPE).into())) ); @@ -276,13 +255,7 @@ mod tests { }); assert_eq!( - resolve_call( - &ctx, - None, - "float", - &[TypeVariable(v2).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "float", &[TypeVariable(v2).into()]), Err("different domain".to_string()) ); } @@ -290,7 +263,6 @@ mod tests { #[test] fn test_methods() { let mut ctx = basic_ctx(); - let mut assumptions = HashMap::new(); let v0 = Rc::new(TypeVariable(ctx.add_variable(VarDef { name: "V0", @@ -324,107 +296,47 @@ mod tests { // simple cases assert_eq!( - resolve_call( - &ctx, - Some(int32.clone()), - "__add__", - &[int32.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), Ok(Some(int32.clone())) ); assert_ne!( - resolve_call( - &ctx, - Some(int32.clone()), - "__add__", - &[int32.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), Ok(Some(int64.clone())) ); assert_eq!( - resolve_call( - &ctx, - Some(int32.clone()), - "__add__", - &[int64.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int64.clone()]), Err("not equal".to_string()) ); // with type variables assert_eq!( - resolve_call( - &ctx, - Some(v1.clone()), - "__add__", - &[v1.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v1.clone()]), Ok(Some(v1.clone())) ); assert_eq!( - resolve_call( - &ctx, - Some(v0.clone()), - "__add__", - &[v2.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v0.clone()), "__add__", &[v2.clone()]), Err("unbounded type var".to_string()) ); assert_eq!( - resolve_call( - &ctx, - Some(v1.clone()), - "__add__", - &[v0.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v0.clone()]), Err("different domain".to_string()) ); assert_eq!( - resolve_call( - &ctx, - Some(v1.clone()), - "__add__", - &[v2.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v2.clone()]), Err("different domain".to_string()) ); assert_eq!( - resolve_call( - &ctx, - Some(v1.clone()), - "__add__", - &[v3.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v3.clone()]), Err("different domain".to_string()) ); assert_eq!( - resolve_call( - &ctx, - Some(v3.clone()), - "__add__", - &[v1.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v1.clone()]), Err("no such function".to_string()) ); assert_eq!( - resolve_call( - &ctx, - Some(v3.clone()), - "__add__", - &[v3.clone()], - &mut assumptions - ), + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v3.clone()]), Err("no such function".to_string()) ); } @@ -432,7 +344,6 @@ mod tests { #[test] fn test_multi_generic() { let mut ctx = basic_ctx(); - let mut assumptions = HashMap::new(); let v0 = Rc::new(TypeVariable(ctx.add_variable(VarDef { name: "V0", bound: vec![], @@ -469,33 +380,15 @@ mod tests { ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[v2.clone(), v2.clone(), v2.clone()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), Ok(Some(v2.clone())) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[v2.clone(), v2.clone(), v3.clone()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), Ok(Some(v2.clone())) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[v2.clone(), v3.clone(), v3.clone()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), Err("different variables".to_string()) ); @@ -504,8 +397,7 @@ mod tests { &ctx, None, "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()], - &mut assumptions + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] ), Ok(Some(v2.clone())) ); @@ -514,8 +406,7 @@ mod tests { &ctx, None, "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()], - &mut assumptions + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] ), Ok(Some(v2.clone())) ); @@ -524,8 +415,7 @@ mod tests { &ctx, None, "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v3.clone(), v3.clone()]).into()], - &mut assumptions + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v3.clone(), v3.clone()]).into()] ), Err("different variables".to_string()) ); @@ -534,7 +424,6 @@ mod tests { #[test] fn test_class_generics() { let mut ctx = basic_ctx(); - let mut assumptions = HashMap::new(); let list = ctx.get_parametric_mut(LIST_TYPE); let t = Rc::new(TypeVariable(list.params[0])); @@ -567,8 +456,7 @@ mod tests { &ctx, Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), "head", - &[], - &mut assumptions + &[] ), Ok(Some(v0.clone())) ); @@ -577,8 +465,7 @@ mod tests { &ctx, Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), "append", - &[v0.clone()], - &mut assumptions + &[v0.clone()] ), Ok(None) ); @@ -587,8 +474,7 @@ mod tests { &ctx, Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), "append", - &[v1.clone()], - &mut assumptions + &[v1.clone()] ), Err("different variables".to_string()) ); @@ -597,7 +483,6 @@ mod tests { #[test] fn test_virtual_class() { let mut ctx = basic_ctx(); - let mut assumptions = HashMap::new(); let foo = ctx.add_class(ClassDef { base: TypeDef { @@ -651,121 +536,55 @@ mod tests { ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[ClassType(foo).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[ClassType(foo1).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[ClassType(foo2).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[ClassType(bar).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), Err("not subtype".to_string()) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ClassType(foo1).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ClassType(foo2).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ClassType(foo).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), Err("not subtype".to_string()) ); // virtual class substitution assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[VirtualClassType(foo).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[VirtualClassType(foo1).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[VirtualClassType(foo2).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), Ok(None) ); assert_eq!( - resolve_call( - &ctx, - None, - "foo", - &[VirtualClassType(bar).into()], - &mut assumptions - ), + resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), Err("not subtype".to_string()) ); }