diff --git a/nac3core/src/inference_core.rs b/nac3core/src/inference_core.rs index 1442a96..8e7796a 100644 --- a/nac3core/src/inference_core.rs +++ b/nac3core/src/inference_core.rs @@ -198,3 +198,404 @@ pub fn resolve_call( resolve_call_rec(ctx, &None, obj, func, args) } +#[cfg(test)] +mod tests { + use super::*; + use crate::context::TopLevelContext; + use crate::primitives::*; + use std::rc::Rc; + + fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { + InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) + } + + #[test] + fn test_simple_generic() { + let mut ctx = basic_ctx(); + let v1 = ctx.add_variable(VarDef { + name: "V1", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], + }); + let v1 = ctx.get_variable(v1); + let v2 = ctx.add_variable(VarDef { + name: "V2", + bound: vec![ + ctx.get_primitive(BOOL_TYPE), + ctx.get_primitive(INT32_TYPE), + ctx.get_primitive(FLOAT_TYPE), + ], + }); + let v2 = ctx.get_variable(v2); + let ctx = get_inference_context(ctx); + + assert_eq!( + resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]), + Ok(Some(ctx.get_primitive(INT32_TYPE))) + ); + + assert_eq!( + resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],), + Ok(Some(ctx.get_primitive(INT32_TYPE))) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]), + Ok(Some(ctx.get_primitive(FLOAT_TYPE))) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]), + Err("different domain".to_string()) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[]), + Err("incorrect parameter number".to_string()) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[v1]), + Ok(Some(ctx.get_primitive(FLOAT_TYPE))) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[v2]), + Err("different domain".to_string()) + ); + } + + #[test] + fn test_methods() { + let mut ctx = basic_ctx(); + + let v0 = ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }); + let v0 = ctx.get_variable(v0); + let v1 = ctx.add_variable(VarDef { + name: "V1", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], + }); + let v1 = ctx.get_variable(v1); + let v2 = ctx.add_variable(VarDef { + name: "V2", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], + }); + let v2 = ctx.get_variable(v2); + let v3 = ctx.add_variable(VarDef { + name: "V3", + bound: vec![ + ctx.get_primitive(BOOL_TYPE), + ctx.get_primitive(INT32_TYPE), + ctx.get_primitive(FLOAT_TYPE), + ], + }); + let v3 = ctx.get_variable(v3); + + let int32 = ctx.get_primitive(INT32_TYPE); + let int64 = ctx.get_primitive(INT64_TYPE); + let ctx = get_inference_context(ctx); + + // simple cases + assert_eq!( + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), + Ok(Some(int32.clone())) + ); + + assert_ne!( + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), + Ok(Some(int64.clone())) + ); + + assert_eq!( + resolve_call(&ctx, Some(int32), "__add__", &[int64]), + Err("not equal".to_string()) + ); + + // with type variables + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v1.clone()]), + Ok(Some(v1.clone())) + ); + assert_eq!( + resolve_call(&ctx, Some(v0.clone()), "__add__", &[v2.clone()]), + Err("unbounded type var".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v0]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v2]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v3.clone()]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v1]), + Err("no such function".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v3]), + Err("no such function".to_string()) + ); + } + + #[test] + fn test_multi_generic() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }); + let v0 = ctx.get_variable(v0); + let v1 = ctx.add_variable(VarDef { + name: "V1", + bound: vec![], + }); + let v1 = ctx.get_variable(v1); + let v2 = ctx.add_variable(VarDef { + name: "V2", + bound: vec![], + }); + let v2 = ctx.get_variable(v2); + let v3 = ctx.add_variable(VarDef { + name: "V3", + bound: vec![], + }); + let v3 = ctx.get_variable(v3); + + ctx.add_fn( + "foo", + FnDef { + args: vec![v0.clone(), v0.clone(), v1.clone()], + result: Some(v0.clone()), + }, + ); + + ctx.add_fn( + "foo1", + FnDef { + args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()], + result: Some(v0), + }, + ); + let ctx = get_inference_context(ctx); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), + Err("different variables".to_string()) + ); + + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] + ), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] + ), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()] + ), + Err("different variables".to_string()) + ); + } + + #[test] + fn test_class_generics() { + let mut ctx = basic_ctx(); + + let list = ctx.get_parametric_def_mut(LIST_TYPE); + let t = Rc::new(TypeVariable(list.params[0])); + list.base.methods.insert( + "head", + FnDef { + args: vec![], + result: Some(t.clone()), + }, + ); + list.base.methods.insert( + "append", + FnDef { + args: vec![t], + result: None, + }, + ); + + let v0 = ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }); + let v0 = ctx.get_variable(v0); + let v1 = ctx.add_variable(VarDef { + name: "V1", + bound: vec![], + }); + let v1 = ctx.get_variable(v1); + let ctx = get_inference_context(ctx); + + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), + "head", + &[] + ), + Ok(Some(v0.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), + "append", + &[v0.clone()] + ), + Ok(None) + ); + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0]).into()), + "append", + &[v1] + ), + Err("different variables".to_string()) + ); + } + + #[test] + fn test_virtual_class() { + let mut ctx = basic_ctx(); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![], + }); + + let foo1 = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo1", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![foo], + }); + + let foo2 = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo2", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![foo1], + }); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "bar", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![], + }); + + ctx.add_fn( + "foo", + FnDef { + args: vec![VirtualClassType(foo).into()], + result: None, + }, + ); + ctx.add_fn( + "foo1", + FnDef { + args: vec![VirtualClassType(foo1).into()], + result: None, + }, + ); + let ctx = get_inference_context(ctx); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), + Err("not subtype".to_string()) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), + Err("not subtype".to_string()) + ); + + // virtual class substitution + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), + Err("not subtype".to_string()) + ); + } +}