use super::context::InferenceContext; use super::typedef::{TypeEnum::*, *}; use std::collections::HashMap; use thiserror::Error; #[derive(Error, Debug)] enum SubstError { #[error("different type variables after substitution")] DifferentSubstVar(VariableId, VariableId), #[error("cannot substitute unbounded type variable into bounded one")] UnboundedTypeVar(VariableId, VariableId), #[error("incompatible bound for type variables")] IncompatibleBound(VariableId, VariableId), #[error("only subtype of virtual class can be substituted into virtual class type")] NotVirtualClassSubtype(Type, ClassId), #[error("different types")] DifferentTypes(Type, Type), } fn find_subst( ctx: &InferenceContext, valuation: &Option<(VariableId, Type)>, sub: &mut HashMap, mut a: Type, mut b: Type, ) -> Result<(), SubstError> { if let TypeVariable(id) = a.as_ref() { if let Some((assumption_id, t)) = valuation { if assumption_id == id { a = t.clone(); } } } let mut substituted = false; if let TypeVariable(id) = b.as_ref() { if let Some(c) = sub.get(&id) { b = c.clone(); substituted = true; } } match (a.as_ref(), b.as_ref()) { (BotType, _) => Ok(()), (TypeVariable(id_a), TypeVariable(id_b)) => { if substituted { return if id_a == id_b { Ok(()) } else { Err(SubstError::DifferentSubstVar(*id_a, *id_b)) }; } let v_a = ctx.get_variable_def(*id_a); let v_b = ctx.get_variable_def(*id_b); if !v_b.bound.is_empty() { if v_a.bound.is_empty() { return Err(SubstError::UnboundedTypeVar(*id_a, *id_b)); } else { let diff: Vec<_> = v_a .bound .iter() .filter(|x| !v_b.bound.contains(x)) .collect(); if !diff.is_empty() { return Err(SubstError::IncompatibleBound(*id_a, *id_b)); } } } sub.insert(*id_b, a.clone()); Ok(()) } (TypeVariable(id_a), _) => { let v_a = ctx.get_variable_def(*id_a); if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { Ok(()) } else { Err(SubstError::DifferentTypes(a.clone(), b.clone())) } } (_, TypeVariable(id_b)) => { let v_b = ctx.get_variable_def(*id_b); if v_b.bound.is_empty() || v_b.bound.contains(&a) { sub.insert(*id_b, a.clone()); Ok(()) } else { Err(SubstError::DifferentTypes(a.clone(), b.clone())) } } (_, VirtualClassType(id_b)) => { let mut parents; match a.as_ref() { ClassType(id_a) => { parents = [*id_a].to_vec(); } VirtualClassType(id_a) => { parents = [*id_a].to_vec(); } _ => { return Err(SubstError::NotVirtualClassSubtype(a.clone(), *id_b)); } }; while !parents.is_empty() { if *id_b == parents[0] { return Ok(()); } let c = ctx.get_class_def(parents.remove(0)); parents.extend_from_slice(&c.parents); } Err(SubstError::NotVirtualClassSubtype(a.clone(), *id_b)) } (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { if id_a != id_b || param_a.len() != param_b.len() { Err(SubstError::DifferentTypes(a.clone(), b.clone())) } else { for (x, y) in param_a.iter().zip(param_b.iter()) { find_subst(ctx, valuation, sub, x.clone(), y.clone())?; } Ok(()) } } (_, _) => { if a == b { Ok(()) } else { Err(SubstError::DifferentTypes(a.clone(), b.clone())) } } } } fn resolve_call_rec( ctx: &InferenceContext, valuation: &Option<(VariableId, Type)>, obj: Option, func: &str, args: &[Type], ) -> Result, String> { let mut subst = obj .as_ref() .map(|v| v.get_subst(ctx)) .unwrap_or_else(HashMap::new); let fun = match &obj { Some(obj) => { let base = match obj.as_ref() { TypeVariable(id) => { let v = ctx.get_variable_def(*id); if v.bound.is_empty() { return Err("unbounded type var".to_string()); } let results: Result, String> = v .bound .iter() .map(|ins| { resolve_call_rec( ctx, &Some((*id, ins.clone())), Some(ins.clone()), func, args.clone(), ) }) .collect(); let results = results?; if results.iter().all(|v| v == &results[0]) { return Ok(results[0].clone()); } let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| { r.as_ref() .map(|v| v.inv_subst(&[(ins.clone(), obj.clone())])) }); let first = results.next().unwrap(); if results.all(|v| v == first) { return Ok(first); } else { return Err("divergent type after substitution".to_string()); } } PrimitiveType(id) => &ctx.get_primitive_def(*id), ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base, ParametricType(id, _) => &ctx.get_parametric_def(*id).base, _ => return Err("not supported".to_string()), }; base.methods.get(func) } None => ctx.get_fn_def(func), } .ok_or_else(|| "no such function".to_string())?; if args.len() != fun.args.len() { return Err("incorrect parameter number".to_string()); } for (a, b) in args.iter().zip(fun.args.iter()) { find_subst(ctx, valuation, &mut subst, a.clone(), b.clone()).map_err(|v| v.to_string())?; } let result = fun.result.as_ref().map(|v| v.subst(&subst)); Ok(result.map(|result| { if let SelfType = result { obj.unwrap() } else { result.into() } })) } pub fn resolve_call( ctx: &InferenceContext, obj: Option, func: &str, args: &[Type], ) -> Result, String> { resolve_call_rec(ctx, &None, obj, func, args) } #[cfg(test)] mod tests { use super::{ super::{context::*, primitives::*}, *, }; use std::matches; use std::rc::Rc; fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) } #[test] fn test_simple_generic() { let mut ctx = basic_ctx(); let v1 = ctx.add_variable(VarDef { name: "V1", bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], }); let v1 = ctx.get_variable(v1); let v2 = ctx.add_variable(VarDef { name: "V2", bound: vec![ ctx.get_primitive(BOOL_TYPE), ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE), ], }); let v2 = ctx.get_variable(v2); let ctx = get_inference_context(ctx); assert_eq!( resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]), Ok(Some(ctx.get_primitive(INT32_TYPE))) ); assert_eq!( resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],), Ok(Some(ctx.get_primitive(INT32_TYPE))) ); assert_eq!( resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]), Ok(Some(ctx.get_primitive(FLOAT_TYPE))) ); assert!(matches!( resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]), Err(..) )); assert!(matches!( resolve_call(&ctx, None, "float", &[]), Err(..) )); assert_eq!( resolve_call(&ctx, None, "float", &[v1]), Ok(Some(ctx.get_primitive(FLOAT_TYPE))) ); assert!(matches!( resolve_call(&ctx, None, "float", &[v2]), Err(..) )); } #[test] fn test_methods() { let mut ctx = basic_ctx(); let v0 = ctx.add_variable(VarDef { name: "V0", bound: vec![], }); let v0 = ctx.get_variable(v0); let v1 = ctx.add_variable(VarDef { name: "V1", bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], }); let v1 = ctx.get_variable(v1); let v2 = ctx.add_variable(VarDef { name: "V2", bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], }); let v2 = ctx.get_variable(v2); let v3 = ctx.add_variable(VarDef { name: "V3", bound: vec![ ctx.get_primitive(BOOL_TYPE), ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE), ], }); let v3 = ctx.get_variable(v3); let int32 = ctx.get_primitive(INT32_TYPE); let int64 = ctx.get_primitive(INT64_TYPE); let ctx = get_inference_context(ctx); // simple cases assert_eq!( resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), Ok(Some(int32.clone())) ); assert_ne!( resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), Ok(Some(int64.clone())) ); assert!(matches!( resolve_call(&ctx, Some(int32), "__add__", &[int64]), Err(..) )); // with type variables assert_eq!( resolve_call(&ctx, Some(v1.clone()), "__add__", &[v1.clone()]), Ok(Some(v1.clone())) ); assert!(matches!( resolve_call(&ctx, Some(v0.clone()), "__add__", &[v2.clone()]), Err(..) )); assert!(matches!( resolve_call(&ctx, Some(v1.clone()), "__add__", &[v0]), Err(..) )); assert!(matches!( resolve_call(&ctx, Some(v1.clone()), "__add__", &[v2]), Err(..) )); assert!(matches!( resolve_call(&ctx, Some(v1.clone()), "__add__", &[v3.clone()]), Err(..) )); assert!(matches!( resolve_call(&ctx, Some(v3.clone()), "__add__", &[v1]), Err(..) )); assert!(matches!( resolve_call(&ctx, Some(v3.clone()), "__add__", &[v3]), Err(..) )); } #[test] fn test_multi_generic() { let mut ctx = basic_ctx(); let v0 = ctx.add_variable(VarDef { name: "V0", bound: vec![], }); let v0 = ctx.get_variable(v0); let v1 = ctx.add_variable(VarDef { name: "V1", bound: vec![], }); let v1 = ctx.get_variable(v1); let v2 = ctx.add_variable(VarDef { name: "V2", bound: vec![], }); let v2 = ctx.get_variable(v2); let v3 = ctx.add_variable(VarDef { name: "V3", bound: vec![], }); let v3 = ctx.get_variable(v3); ctx.add_fn( "foo", FnDef { args: vec![v0.clone(), v0.clone(), v1.clone()], result: Some(v0.clone()), }, ); ctx.add_fn( "foo1", FnDef { args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()], result: Some(v0), }, ); let ctx = get_inference_context(ctx); assert_eq!( resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), Ok(Some(v2.clone())) ); assert_eq!( resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), Ok(Some(v2.clone())) ); assert!(matches!( resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), Err(..) )); assert_eq!( resolve_call( &ctx, None, "foo1", &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] ), Ok(Some(v2.clone())) ); assert_eq!( resolve_call( &ctx, None, "foo1", &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] ), Ok(Some(v2.clone())) ); assert!(matches!( resolve_call( &ctx, None, "foo1", &[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()] ), Err(..) )); } #[test] fn test_class_generics() { let mut ctx = basic_ctx(); let list = ctx.get_parametric_def_mut(LIST_TYPE); let t = Rc::new(TypeVariable(list.params[0])); list.base.methods.insert( "head", FnDef { args: vec![], result: Some(t.clone()), }, ); list.base.methods.insert( "append", FnDef { args: vec![t], result: None, }, ); let v0 = ctx.add_variable(VarDef { name: "V0", bound: vec![], }); let v0 = ctx.get_variable(v0); let v1 = ctx.add_variable(VarDef { name: "V1", bound: vec![], }); let v1 = ctx.get_variable(v1); let ctx = get_inference_context(ctx); assert_eq!( resolve_call( &ctx, Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), "head", &[] ), Ok(Some(v0.clone())) ); assert_eq!( resolve_call( &ctx, Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), "append", &[v0.clone()] ), Ok(None) ); assert!(matches!( resolve_call( &ctx, Some(ParametricType(LIST_TYPE, vec![v0]).into()), "append", &[v1] ), Err(..) )); } #[test] fn test_virtual_class() { let mut ctx = basic_ctx(); let foo = ctx.add_class(ClassDef { base: TypeDef { name: "Foo", methods: HashMap::new(), fields: HashMap::new(), }, parents: vec![], }); let foo1 = ctx.add_class(ClassDef { base: TypeDef { name: "Foo1", methods: HashMap::new(), fields: HashMap::new(), }, parents: vec![foo], }); let foo2 = ctx.add_class(ClassDef { base: TypeDef { name: "Foo2", methods: HashMap::new(), fields: HashMap::new(), }, parents: vec![foo1], }); let bar = ctx.add_class(ClassDef { base: TypeDef { name: "bar", methods: HashMap::new(), fields: HashMap::new(), }, parents: vec![], }); ctx.add_fn( "foo", FnDef { args: vec![VirtualClassType(foo).into()], result: None, }, ); ctx.add_fn( "foo1", FnDef { args: vec![VirtualClassType(foo1).into()], result: None, }, ); let ctx = get_inference_context(ctx); assert_eq!( resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]), Ok(None) ); assert_eq!( resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), Ok(None) ); assert_eq!( resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), Ok(None) ); assert!(matches!( resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), Err(..) )); assert_eq!( resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), Ok(None) ); assert_eq!( resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), Ok(None) ); assert!(matches!( resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), Err(..) )); // virtual class substitution assert_eq!( resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), Ok(None) ); assert_eq!( resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), Ok(None) ); assert_eq!( resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), Ok(None) ); assert!(matches!( resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), Err(..) )); } }