From a7e3eeea0dd73a63a0d74a3d23afaa7f77f613b2 Mon Sep 17 00:00:00 2001 From: CrescentonC Date: Mon, 2 Aug 2021 17:36:37 +0800 Subject: [PATCH] add primitive magic method support; change from TypeEnum::TObj { fields: Mapping, ..} to TypeEnum::TObj {fields: RefCell>, .. } for interior mutability --- nac3core/src/typecheck/magic_methods.rs | 123 ++++++++------- .../src/typecheck/type_inferencer/test.rs | 146 ++++++++++++++++-- nac3core/src/typecheck/typedef/mod.rs | 12 +- nac3core/src/typecheck/typedef/test.rs | 10 +- 4 files changed, 208 insertions(+), 83 deletions(-) diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 8bf545c7..eacf8180 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,5 +1,5 @@ +use std::borrow::Borrow; use std::collections::HashMap; -use std::rc::Rc; use rustpython_parser::ast::{Cmpop, Operator, Unaryop}; pub fn binop_name(op: &Operator) -> &'static str { @@ -64,14 +64,14 @@ use rustpython_parser::ast; /// Add, Sub, Mult, Pow pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { - if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { for op in &[ ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult, ast::Operator::Pow, ] { - fields.insert( + fields.borrow_mut().insert( binop_name(op).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -84,7 +84,7 @@ pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: })) ); - fields.insert( + fields.borrow_mut().insert( binop_assign_name(op).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.none, @@ -102,7 +102,7 @@ pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: /// LShift, RShift, BitOr, BitXor, BitAnd pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { for op in &[ ast::Operator::LShift, ast::Operator::RShift, @@ -110,7 +110,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty ast::Operator::BitXor, ast::Operator::BitAnd, ] { - fields.insert( + fields.borrow_mut().insert( binop_name(op).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ty, @@ -123,7 +123,7 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty })) ); - fields.insert( + fields.borrow_mut().insert( binop_assign_name(op).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.none, @@ -141,8 +141,8 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty /// Div pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { - fields.insert( + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { + fields.borrow_mut().insert( binop_name(&ast::Operator::Div).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature{ ret: store.float, @@ -155,7 +155,7 @@ pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t })) ); - fields.insert( + fields.borrow_mut().insert( binop_assign_name(&ast::Operator::Div).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature{ ret: store.none, @@ -172,8 +172,8 @@ pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t /// FloorDiv pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { - fields.insert( + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { + fields.borrow_mut().insert( binop_name(&ast::Operator::FloorDiv).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature{ ret: ret_ty, @@ -186,7 +186,7 @@ pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ot })) ); - fields.insert( + fields.borrow_mut().insert( binop_assign_name(&ast::Operator::FloorDiv).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature{ ret: store.none, @@ -203,8 +203,8 @@ pub fn impl_floordiv(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ot /// Mod pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type, ret_ty: Type) { - if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { - fields.insert( + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { + fields.borrow_mut().insert( binop_name(&ast::Operator::Mod).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -217,7 +217,7 @@ pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t })) ); - fields.insert( + fields.borrow_mut().insert( binop_assign_name(&ast::Operator::Mod).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.none, @@ -234,12 +234,12 @@ pub fn impl_mod(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_t /// UAdd, USub pub fn impl_unary_op(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { for op in &[ ast::Unaryop::UAdd, ast::Unaryop::USub ] { - fields.insert( + fields.borrow_mut().insert( unaryop_name(op).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ty, @@ -253,8 +253,8 @@ pub fn impl_unary_op(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { /// Invert pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { - if let Some(TypeEnum::TObj {fields, .. }) = Rc::get_mut(&mut unifier.get_ty(ty)) { - fields.insert( + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { + fields.borrow_mut().insert( unaryop_name(&ast::Unaryop::Invert).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ty, @@ -267,8 +267,8 @@ pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { /// Not pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { - fields.insert( + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { + fields.borrow_mut().insert( unaryop_name(&ast::Unaryop::Not).into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.bool, @@ -281,14 +281,14 @@ pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { /// Lt, LtE, Gt, GtE pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { for op in &[ ast::Cmpop::Lt, ast::Cmpop::LtE, ast::Cmpop::Gt, ast::Cmpop::GtE, ] { - fields.insert( + fields.borrow_mut().insert( comparison_name(op).unwrap().into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.bool, @@ -306,12 +306,12 @@ pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, /// Eq, NotEq pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - if let Some(TypeEnum::TObj {fields, ..}) = Rc::get_mut(&mut unifier.get_ty(ty)) { + if let TypeEnum::TObj {fields, .. } = unifier.get_ty(ty).borrow() { for op in &[ ast::Cmpop::Eq, ast::Cmpop::NotEq, ] { - fields.insert( + fields.borrow_mut().insert( comparison_name(op).unwrap().into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: store.bool, @@ -335,68 +335,71 @@ pub fn set_primirives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, none: _none_t } = *store; - // int32 -------- + /* int32 ======== */ impl_basic_arithmetic(unifier, store, int32_t, int32_t, int32_t); - impl_basic_arithmetic(unifier, store, int32_t, int64_t, int64_t); - impl_basic_arithmetic(unifier, store, int32_t, float_t, float_t); + // impl_basic_arithmetic(unifier, store, int32_t, int64_t, int64_t); + // impl_basic_arithmetic(unifier, store, int32_t, float_t, float_t); impl_bitwise_arithmetic(unifier, store, int32_t); + // impl_div(unifier, store, int32_t, int32_t); + // impl_div(unifier, store, int32_t, int64_t); impl_div(unifier, store, int32_t, int32_t); - impl_div(unifier, store, int32_t, int64_t); - impl_div(unifier, store, int32_t, float_t); impl_floordiv(unifier, store, int32_t, int32_t, int32_t); - impl_floordiv(unifier, store, int32_t, int64_t, int32_t); - impl_floordiv(unifier, store, int32_t, float_t, float_t); + // impl_floordiv(unifier, store, int32_t, int64_t, int32_t); + // impl_floordiv(unifier, store, int32_t, float_t, float_t); impl_mod(unifier, store, int32_t, int32_t, int32_t); - impl_mod(unifier, store, int32_t, int64_t, int32_t); - impl_mod(unifier, store, int32_t, float_t, float_t); + // impl_mod(unifier, store, int32_t, int64_t, int32_t); + // impl_mod(unifier, store, int32_t, float_t, float_t); impl_unary_op(unifier, store, int32_t); impl_invert(unifier, store, int32_t); impl_not(unifier, store, int32_t); impl_comparison(unifier, store, int32_t, int32_t); - impl_comparison(unifier, store, int32_t, int64_t); - impl_comparison(unifier, store, int32_t, float_t); + // impl_comparison(unifier, store, int32_t, int64_t); + // impl_comparison(unifier, store, int32_t, float_t); impl_eq(unifier, store, int32_t); - // int64 -------- - impl_basic_arithmetic(unifier, store, int64_t, int32_t, int64_t); + + /* int64 ======== */ + // impl_basic_arithmetic(unifier, store, int64_t, int32_t, int64_t); impl_basic_arithmetic(unifier, store, int64_t, int64_t, int64_t); - impl_basic_arithmetic(unifier, store, int64_t, float_t, float_t); + // impl_basic_arithmetic(unifier, store, int64_t, float_t, float_t); impl_bitwise_arithmetic(unifier, store, int64_t); - impl_div(unifier, store, int64_t, int32_t); + // impl_div(unifier, store, int64_t, int32_t); impl_div(unifier, store, int64_t, int64_t); - impl_div(unifier, store, int64_t, float_t); - impl_floordiv(unifier, store, int64_t, int32_t, int64_t); + // impl_div(unifier, store, int64_t, float_t); + // impl_floordiv(unifier, store, int64_t, int32_t, int64_t); impl_floordiv(unifier, store, int64_t, int64_t, int64_t); - impl_floordiv(unifier, store, int64_t, float_t, float_t); - impl_mod(unifier, store, int64_t, int32_t, int64_t); + // impl_floordiv(unifier, store, int64_t, float_t, float_t); + // impl_mod(unifier, store, int64_t, int32_t, int64_t); impl_mod(unifier, store, int64_t, int64_t, int64_t); - impl_mod(unifier, store, int64_t, float_t, float_t); + // impl_mod(unifier, store, int64_t, float_t, float_t); impl_unary_op(unifier, store, int64_t); impl_invert(unifier, store, int64_t); impl_not(unifier, store, int64_t); - impl_comparison(unifier, store, int64_t, int32_t); + // impl_comparison(unifier, store, int64_t, int32_t); impl_comparison(unifier, store, int64_t, int64_t); - impl_comparison(unifier, store, int64_t, float_t); + // impl_comparison(unifier, store, int64_t, float_t); impl_eq(unifier, store, int64_t); - // float -------- - impl_basic_arithmetic(unifier, store, float_t, int32_t, float_t); - impl_basic_arithmetic(unifier, store, float_t, int64_t, float_t); + + /* float ======== */ + // impl_basic_arithmetic(unifier, store, float_t, int32_t, float_t); + // impl_basic_arithmetic(unifier, store, float_t, int64_t, float_t); impl_basic_arithmetic(unifier, store, float_t, float_t, float_t); - impl_div(unifier, store, float_t, int32_t); - impl_div(unifier, store, float_t, int64_t); + // impl_div(unifier, store, float_t, int32_t); + // impl_div(unifier, store, float_t, int64_t); impl_div(unifier, store, float_t, float_t); - impl_floordiv(unifier, store, float_t, int32_t, float_t); - impl_floordiv(unifier, store, float_t, int64_t, float_t); + // impl_floordiv(unifier, store, float_t, int32_t, float_t); + // impl_floordiv(unifier, store, float_t, int64_t, float_t); impl_floordiv(unifier, store, float_t, float_t, float_t); - impl_mod(unifier, store, float_t, int32_t, float_t); - impl_mod(unifier, store, float_t, int64_t, float_t); + // impl_mod(unifier, store, float_t, int32_t, float_t); + // impl_mod(unifier, store, float_t, int64_t, float_t); impl_mod(unifier, store, float_t, float_t, float_t); impl_unary_op(unifier, store, float_t); impl_not(unifier, store, float_t); - impl_comparison(unifier, store, float_t, int32_t); - impl_comparison(unifier, store, float_t, int64_t); + // impl_comparison(unifier, store, float_t, int32_t); + // impl_comparison(unifier, store, float_t, int64_t); impl_comparison(unifier, store, float_t, float_t); impl_eq(unifier, store, float_t); - // bool --------- + + /* bool ======== */ impl_not(unifier, store, bool_t); impl_eq(unifier, store, bool_t); } \ No newline at end of file diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index eec99f2a..16200dc9 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -51,31 +51,32 @@ impl TestEnvironment { let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: 0, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: 1, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: 2, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: 3, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: 4, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); // identifier_mapping.insert("None".into(), none); let primitives = PrimitiveStore { int32, int64, float, bool, none }; + set_primirives_magic_methods(&primitives, &mut unifier); let id_to_name = [ @@ -119,27 +120,27 @@ impl TestEnvironment { let mut identifier_mapping = HashMap::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: 0, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: 1, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: 2, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: 3, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: 4, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }); identifier_mapping.insert("None".into(), none); @@ -150,7 +151,7 @@ impl TestEnvironment { let foo_ty = unifier.add_ty(TypeEnum::TObj { obj_id: 5, - fields: [("a".into(), v0)].iter().cloned().collect(), + fields: [("a".into(), v0)].iter().cloned().collect::>().into(), params: [(id, v0)].iter().cloned().collect(), }); @@ -170,7 +171,7 @@ impl TestEnvironment { })); let bar = unifier.add_ty(TypeEnum::TObj { obj_id: 6, - fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect(), + fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect::>().into(), params: Default::default(), }); identifier_mapping.insert( @@ -184,7 +185,7 @@ impl TestEnvironment { let bar2 = unifier.add_ty(TypeEnum::TObj { obj_id: 7, - fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect(), + fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect::>().into(), params: Default::default(), }); identifier_mapping.insert( @@ -350,3 +351,122 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st assert_eq!(&b, y); } } + +#[test_case(indoc! {" + a = 2 + b = 2 + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + "}, + [("a", "int32"), + ("b", "int32"), + ("c", "int32"), + ("d", "int32"), + ("e", "int32"), + ("f", "float"), + ("g", "int32"), + ("h", "int32")].iter().cloned().collect() + ; "int32")] +#[test_case( + indoc! {" + a = 2.4 + b = 3.6 + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + "}, + [("a", "float"), + ("b", "float"), + ("c", "float"), + ("d", "float"), + ("e", "float"), + ("f", "float"), + ("g", "float"), + ("h", "float")].iter().cloned().collect() + ; "float" +)] +#[test_case( + indoc! {" + a = int64(12312312312) + b = int64(24242424424) + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + i = a == b + j = a > b + k = a < b + l = a != b + "}, + [("a", "int64"), + ("b", "int64"), + ("c", "int64"), + ("d", "int64"), + ("e", "int64"), + ("f", "float"), + ("g", "int64"), + ("h", "int64"), + ("i", "bool"), + ("j", "bool"), + ("k", "bool"), + ("l", "bool")].iter().cloned().collect() + ; "int64" +)] +#[test_case( + indoc! {" + a = True + b = False + c = a == b + d = not a + e = a != b + "}, + [("a", "bool"), + ("b", "bool"), + ("c", "bool"), + ("d", "bool"), + ("e", "bool")].iter().cloned().collect() + ; "boolean" +)] +fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { + println!("source:\n{}", source); + let mut env = TestEnvironment::basic_test_env(); + let id_to_name = std::mem::take(&mut env.id_to_name); + let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); + defined_identifiers.push("virtual".to_string()); + let mut inferencer = env.get_inferencer(); + let statements = parse_program(source).unwrap(); + let statements = statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); + + for (k, v) in inferencer.variable_mapping.iter() { + let name = inferencer.unifier.stringify( + *v, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + println!("{}: {}", k, name); + } + for (k, v) in mapping.iter() { + let ty = inferencer.variable_mapping.get(*k).unwrap(); + let name = inferencer.unifier.stringify( + *ty, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); + } +} \ No newline at end of file diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e3dc8337..dbe739d1 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -64,7 +64,7 @@ pub enum TypeEnum { }, TObj { obj_id: usize, - fields: Mapping, + fields: RefCell>, params: VarMap, }, TVirtual { @@ -373,7 +373,8 @@ impl Unifier { (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { self.occur_check(a, b)?; for (k, v) in map.borrow().iter() { - let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?; + let temp = fields.borrow(); + let ty = temp.get(k).ok_or_else(|| format!("No such attribute {}", k))?; self.unify(*ty, *v)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); @@ -385,7 +386,8 @@ impl Unifier { let ty = self.get_ty(*ty); if let TObj { fields, .. } = ty.as_ref() { for (k, v) in map.borrow().iter() { - let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?; + let temp = fields.borrow(); + let ty = temp.get(k).ok_or_else(|| format!("No such attribute {}", k))?; if !matches!(self.get_ty(*ty).as_ref(), TFunc { .. }) { return Err(format!("Cannot access field {} for virtual type", k)); } @@ -659,8 +661,8 @@ impl Unifier { if need_subst { let obj_id = *obj_id; let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone()); - let fields = self.subst_map(&fields, mapping).unwrap_or_else(|| fields.clone()); - Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields })) + let fields = self.subst_map(&fields.borrow(), mapping).unwrap_or_else(|| fields.borrow().clone()); + Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() })) } else { None } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 78cb77c5..f05816db 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -79,7 +79,7 @@ impl TestEnvironment { "int".into(), unifier.add_ty(TypeEnum::TObj { obj_id: 0, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }), ); @@ -87,7 +87,7 @@ impl TestEnvironment { "float".into(), unifier.add_ty(TypeEnum::TObj { obj_id: 1, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }), ); @@ -95,7 +95,7 @@ impl TestEnvironment { "bool".into(), unifier.add_ty(TypeEnum::TObj { obj_id: 2, - fields: HashMap::new(), + fields: HashMap::new().into(), params: HashMap::new(), }), ); @@ -104,7 +104,7 @@ impl TestEnvironment { "Foo".into(), unifier.add_ty(TypeEnum::TObj { obj_id: 3, - fields: [("a".into(), v0)].iter().cloned().collect(), + fields: [("a".into(), v0)].iter().cloned().collect::>().into(), params: [(id, v0)].iter().cloned().collect(), }), ); @@ -335,7 +335,7 @@ fn test_virtual() { })); let bar = env.unifier.add_ty(TypeEnum::TObj { obj_id: 5, - fields: [("f".to_string(), fun), ("a".to_string(), int)].iter().cloned().collect(), + fields: [("f".to_string(), fun), ("a".to_string(), int)].iter().cloned().collect::>().into(), params: HashMap::new(), }); let v0 = env.unifier.get_fresh_var().0;