From b1e83a1fd47b0e2a1e20db2a3665d705b0ef8c8a Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 6 Nov 2021 22:48:08 +0800 Subject: [PATCH] nac3core: type check invariants This rejects code that tries to assign to KernelInvariant fields and methods. --- nac3artiq/src/symbol_resolver.rs | 10 +- nac3core/src/codegen/concrete_type.rs | 6 +- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/symbol_resolver.rs | 33 +++--- nac3core/src/toplevel/composer.rs | 53 ++++----- nac3core/src/toplevel/helper.rs | 2 +- nac3core/src/toplevel/mod.rs | 3 +- nac3core/src/toplevel/type_annotation.rs | 20 +--- nac3core/src/typecheck/magic_methods.rs | 90 ++++++++------ nac3core/src/typecheck/type_inferencer/mod.rs | 112 +++++++++++++----- .../src/typecheck/type_inferencer/test.rs | 22 ++-- nac3core/src/typecheck/typedef/mod.rs | 65 +++++++--- nac3core/src/typecheck/typedef/test.rs | 43 +++++-- 13 files changed, 282 insertions(+), 179 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 6c255800..e420bd80 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -135,7 +135,7 @@ impl Resolver { .collect(); let mut fields_ty = HashMap::new(); for method in methods.iter() { - fields_ty.insert(method.0, method.1); + fields_ty.insert(method.0, (method.1, false)); } for field in fields.iter() { let name: String = field.0.into(); @@ -148,7 +148,7 @@ impl Resolver { // field type mismatch return Ok(None); } - fields_ty.insert(field.0, ty); + fields_ty.insert(field.0, (ty, field.2)); } for (_, ty) in var_map.iter() { // must be concrete type @@ -379,7 +379,7 @@ impl Resolver { if let TopLevelDef::Class { fields, .. } = &*definition { let values: Result>, _> = fields .iter() - .map(|(name, _)| { + .map(|(name, _, _)| { self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx) }) .collect(); @@ -413,7 +413,7 @@ impl SymbolResolver for Resolver { id_to_type.get(&str).cloned().or_else(|| { let py_id = self.name_to_pyid.get(&str); let result = py_id.and_then(|id| { - self.pyid_to_type.read().get(&id).copied().or_else(|| { + self.pyid_to_type.read().get(id).copied().or_else(|| { Python::with_gil(|py| -> PyResult> { let obj: &PyAny = self.module.extract(py)?; let members: &PyList = PyModule::import(py, "inspect")? @@ -491,7 +491,7 @@ impl SymbolResolver for Resolver { let mut id_to_def = self.id_to_def.lock(); id_to_def.get(&id).cloned().or_else(|| { let py_id = self.name_to_pyid.get(&id); - let result = py_id.and_then(|id| self.pyid_to_def.read().get(&id).copied()); + let result = py_id.and_then(|id| self.pyid_to_def.read().get(id).copied()); if let Some(result) = &result { id_to_def.insert(id, *result); } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 3d734a8f..f4f4518b 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -44,7 +44,7 @@ pub enum ConcreteTypeEnum { }, TObj { obj_id: DefinitionId, - fields: HashMap, + fields: HashMap, params: HashMap, }, TVirtual { @@ -148,7 +148,7 @@ impl ConcreteTypeStore { .borrow() .iter() .map(|(name, ty)| { - (*name, self.from_unifier_type(unifier, primitives, *ty, cache)) + (*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)) }) .collect(), params: params @@ -225,7 +225,7 @@ impl ConcreteTypeStore { fields: fields .iter() .map(|(name, cty)| { - (*name, self.to_unifier_type(unifier, primitives, *cty, cache)) + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) }) .collect::>() .into(), diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1ec0cf4e..d64212b6 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -232,7 +232,7 @@ fn get_llvm_type<'ctx>( let fields = fields.borrow(); let fields = fields_list .iter() - .map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0])) + .map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0].0)) .collect_vec(); ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() } else { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 5395a4dd..161d78a9 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -2,16 +2,19 @@ use std::collections::HashMap; use std::fmt::Debug; use std::{cell::RefCell, sync::Arc}; -use crate::{codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}}; use crate::typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, }; +use crate::{ + codegen::CodeGenContext, + toplevel::{DefinitionId, TopLevelDef}, +}; use crate::{location::Location, typecheck::typedef::TypeEnum}; -use itertools::{chain, izip}; -use parking_lot::RwLock; -use nac3parser::ast::{Expr, StrRef}; use inkwell::values::BasicValueEnum; +use itertools::{chain, izip}; +use nac3parser::ast::{Expr, StrRef}; +use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] pub enum SymbolValue { @@ -35,7 +38,11 @@ pub trait SymbolResolver { ) -> Option; // get the top-level definition of identifiers fn get_identifier_def(&self, str: StrRef) -> Option; - fn get_symbol_value<'ctx, 'a>(&self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>) -> Option>; + fn get_symbol_value<'ctx, 'a>( + &self, + str: StrRef, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> Option>; fn get_symbol_location(&self, str: StrRef) -> Option; // handle function call etc. } @@ -62,9 +69,7 @@ pub fn parse_type_annotation( expr: &Expr, ) -> Result { use nac3parser::ast::ExprKind::*; - let ids = IDENTIFIER_ID.with(|ids| { - *ids - }); + let ids = IDENTIFIER_ID.with(|ids| *ids); let int32_id = ids[0]; let int64_id = ids[1]; let float_id = ids[2]; @@ -99,8 +104,8 @@ pub fn parse_type_annotation( } let fields = RefCell::new( chain( - fields.iter().map(|(k, v)| (*k, *v)), - methods.iter().map(|(k, v, _)| (*k, *v)), + fields.iter().map(|(k, v, m)| (*k, (*v, *m))), + methods.iter().map(|(k, v, _)| (*k, (*v, false))), ) .collect(), ); @@ -124,7 +129,7 @@ pub fn parse_type_annotation( } } } - }, + } Subscript { value, slice, .. } => { if let Name { id, .. } = &value.node { if *id == virtual_id { @@ -209,14 +214,14 @@ pub fn parse_type_annotation( } let mut fields = fields .iter() - .map(|(attr, ty)| { + .map(|(attr, ty, is_mutable)| { let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, ty) + (*attr, (ty, *is_mutable)) }) .collect::>(); fields.extend(methods.iter().map(|(attr, ty, _)| { let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, ty) + (*attr, (ty, false)) })); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a3d55f1f..99f67b19 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -461,7 +461,7 @@ impl TopLevelComposer { "str".into(), "self".into(), "Kernel".into(), - "KernelImmutable".into(), + "KernelInvariant".into(), ]); let defined_names: HashSet = Default::default(); let method_class: HashMap = Default::default(); @@ -1391,22 +1391,24 @@ impl TopLevelComposer { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_fresh_var().0; - class_fields_def.push((*attr, dummy_field_type)); - // handle Kernel[T], KernelImmutable[T] - let annotation = { - match &annotation.as_ref().node { - ast::ExprKind::Subscript { value, slice, .. } - if { - matches!(&value.node, ast::ExprKind::Name { id, .. } - if id == &"Kernel".into() || id == &"KernelImmutable".into()) - } => - { - slice + // handle Kernel[T], KernelInvariant[T] + let (annotation, mutable) = { + let mut result = None; + if let ast::ExprKind::Subscript { value, slice, .. } = &annotation.as_ref().node { + if let ast::ExprKind::Name { id, .. } = &value.node { + result = if id == &"Kernel".into() { + Some((slice, true)) + } else if id == &"KernelInvariant".into() { + Some((slice, false)) + } else { + None + } } - _ => annotation, } + result.unwrap_or((annotation, true)) }; + class_fields_def.push((*attr, dummy_field_type, mutable)); let annotation = parse_ast_to_type_annotation_kinds( class_resolver, @@ -1532,26 +1534,13 @@ impl TopLevelComposer { class_methods_def.extend(new_child_methods); // handle class fields - let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new(); + let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new(); // let mut is_override: HashSet<_> = HashSet::new(); - for (anc_field_name, anc_field_ty) in fields { - let to_be_added = (*anc_field_name, *anc_field_ty); + for (anc_field_name, anc_field_ty, mutable) in fields { + let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); // find if there is a fields with the same name in the child class for (class_field_name, ..) in class_fields_def.iter() { if class_field_name == anc_field_name { - // let ok = Self::check_overload_field_type( - // *class_field_ty, - // *anc_field_ty, - // unifier, - // type_var_to_concrete_def, - // ); - // if !ok { - // return Err("fields has same name as ancestors' field, but incompatible type".into()); - // } - // // mark it as added - // is_override.insert(class_field_name.to_string()); - // to_be_added = (class_field_name.to_string(), *class_field_ty); - // break; return Err(format!( "field `{}` has already declared in the ancestor classes", class_field_name @@ -1560,9 +1549,9 @@ impl TopLevelComposer { } new_child_fields.push(to_be_added); } - for (class_field_name, class_field_ty) in class_fields_def.iter() { + for (class_field_name, class_field_ty, mutable) in class_fields_def.iter() { if !is_override.contains(class_field_name) { - new_child_fields.push((*class_field_name, *class_field_ty)); + new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); } } class_fields_def.drain(..); @@ -1636,7 +1625,7 @@ impl TopLevelComposer { unreachable!("must be init function here") } let all_inited = Self::get_all_assigned_field(body.as_slice())?; - if fields.iter().any(|(x, _)| !all_inited.contains(x)) { + if fields.iter().any(|x| !all_inited.contains(&x.0)) { return Err(format!( "fields of class {} not fully initialized", class_name diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index a3869943..9c2a493e 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -12,7 +12,7 @@ impl TopLevelDef { } => { let fields_str = fields .iter() - .map(|(n, ty)| { + .map(|(n, ty, _)| { (n.to_string(), unifier.default_stringify(*ty)) }) .collect_vec(); diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index f8057112..c96a7cac 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -85,7 +85,8 @@ pub enum TopLevelDef { /// type variables bounded to the class. type_vars: Vec, // class fields - fields: Vec<(StrRef, Type)>, + // name, type, is mutable + fields: Vec<(StrRef, Type, bool)>, // class methods, pointing to the corresponding function definition. methods: Vec<(StrRef, Type, DefinitionId)>, // ancestor classes, including itself. diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 8db0951e..7cb00362 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -307,25 +307,15 @@ pub fn get_type_from_type_annotation_kinds( .iter() .map(|(name, ty, _)| { let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*name, subst_ty) + // methods are immutable + (*name, (subst_ty, false)) }) - .collect::>(); - tobj_fields.extend(fields.iter().map(|(name, ty)| { + .collect::>(); + tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*name, subst_ty) + (*name, (subst_ty, *mutability)) })); - // println!("tobj_fields: {:?}", tobj_fields); - // println!( - // "{:?}: {}\n", - // tobj_fields.get("__init__").unwrap(), - // unifier.stringify( - // *tobj_fields.get("__init__").unwrap(), - // &mut |id| format!("class{}", id), - // &mut |id| format!("tvar{}", id) - // ) - // ); - Ok(unifier.add_ty(TypeEnum::TObj { obj_id: *obj_id, fields: RefCell::new(tobj_fields), diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 0dcc0ea9..43d0da59 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -86,33 +86,39 @@ pub fn impl_binop( }; for op in ops { fields.borrow_mut().insert(binop_name(op).into(), { - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: ret_ty, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - .into(), - )) + ( + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: ret_ty, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )), + false, + ) }); fields.borrow_mut().insert(binop_assign_name(op).into(), { - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: store.none, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - .into(), - )) + ( + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: store.none, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )), + false, + ) }); } } else { @@ -131,9 +137,12 @@ pub fn impl_unaryop( for op in ops { fields.borrow_mut().insert( unaryop_name(op).into(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(), - )), + ( + unifier.add_ty(TypeEnum::TFunc( + FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(), + )), + false, + ), ); } } else { @@ -152,18 +161,21 @@ pub fn impl_cmpop( for op in ops { fields.borrow_mut().insert( comparison_name(op).unwrap().into(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: store.bool, - vars: HashMap::new(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - .into(), - )), + ( + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: store.bool, + vars: HashMap::new(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )), + false, + ), ); } } else { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 26f544d8..c0c5c404 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -10,7 +10,7 @@ use itertools::izip; use nac3parser::ast::{ self, fold::{self, Fold}, - Arguments, Comprehension, ExprKind, Located, Location, StrRef, + Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, }; #[cfg(test)] @@ -77,11 +77,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { Ok(None) } - fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result, Self::Error> { + fn fold_stmt( + &mut self, + mut node: ast::Stmt<()>, + ) -> Result, Self::Error> { let stmt = match node.node { // we don't want fold over type annotation - ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment } => { + ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => { self.infer_pattern(&target)?; + // fix parser problem... + if let ExprKind::Attribute { ctx, .. } = &mut target.node { + *ctx = ExprContext::Store; + } + let target = Box::new(self.fold_expr(*target)?); let value = if let Some(v) = value { let ty = Box::new(self.fold_expr(*v)?); @@ -105,14 +113,25 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { Located { location: node.location, custom: None, - node: ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment }, + node: ast::StmtKind::AnnAssign { + target, + annotation, + value, + simple, + config_comment, + }, } } ast::StmtKind::For { ref target, .. } => { self.infer_pattern(target)?; fold::fold_stmt(self, node)? } - ast::StmtKind::Assign { ref targets, ref config_comment, .. } => { + ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { + for target in targets.iter_mut() { + if let ExprKind::Attribute { ctx, .. } = &mut target.node { + *ctx = ExprContext::Store; + } + } if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) { if let ast::StmtKind::Assign { targets, value, .. } = node.node { let value = self.fold_expr(*value)?; @@ -158,7 +177,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { targets, value: Box::new(value), type_comment: None, - config_comment: config_comment.clone() + config_comment: config_comment.clone(), }, custom: None, }); @@ -199,7 +218,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} - ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Pass { .. } => {} + ast::StmtKind::Break { .. } + | ast::StmtKind::Continue { .. } + | ast::StmtKind::Pass { .. } => {} ast::StmtKind::With { items, .. } => { for item in items.iter() { let ty = item.context_expr.custom.unwrap(); @@ -209,17 +230,21 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { let fields = fields.borrow(); fast_path = true; if let Some(enter) = fields.get(&"__enter__".into()).cloned() { - if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) { + if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) { let signature = signature.borrow(); if !signature.args.is_empty() { return report_error( "__enter__ method should take no argument other than self", - stmt.location - ) + stmt.location, + ); } if let Some(var) = &item.optional_vars { if signature.vars.is_empty() { - self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?; + self.unify( + signature.ret, + var.custom.unwrap(), + &stmt.location, + )?; } else { fast_path = false; } @@ -230,17 +255,17 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } else { return report_error( "__enter__ method is required for context manager", - stmt.location + stmt.location, ); } if let Some(exit) = fields.get(&"__exit__".into()).cloned() { - if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) { + if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) { let signature = signature.borrow(); if !signature.args.is_empty() { return report_error( "__exit__ method should take no argument other than self", - stmt.location - ) + stmt.location, + ); } } else { fast_path = false; @@ -248,26 +273,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } else { return report_error( "__exit__ method is required for context manager", - stmt.location + stmt.location, ); } } if !fast_path { let enter = TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![], - ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()), - vars: Default::default() + ret: item.optional_vars.as_ref().map_or_else( + || self.unifier.get_fresh_var().0, + |var| var.custom.unwrap(), + ), + vars: Default::default(), })); let enter = self.unifier.add_ty(enter); let exit = TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![], ret: self.unifier.get_fresh_var().0, - vars: Default::default() + vars: Default::default(), })); let exit = self.unifier.add_ty(exit); let mut fields = HashMap::new(); - fields.insert("__enter__".into(), enter); - fields.insert("__exit__".into(), exit); + fields.insert("__enter__".into(), (enter, false)); + fields.insert("__exit__".into(), (exit, false)); let record = self.unifier.add_record(fields); self.unify(ty, record, &stmt.location)?; } @@ -330,8 +358,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), - ast::ExprKind::Attribute { value, attr, ctx: _ } => { - Some(self.infer_attribute(value, *attr)?) + ast::ExprKind::Attribute { value, attr, ctx } => { + Some(self.infer_attribute(value, *attr, ctx)?) } ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BinOp { left, op, right } => { @@ -399,7 +427,8 @@ impl<'a> Inferencer<'a> { if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { if class_params.borrow().is_empty() { if let Some(ty) = fields.borrow().get(&method) { - if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) { + let ty = ty.0; + if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { let sign = sign.borrow(); if sign.vars.is_empty() { let call = Call { @@ -419,7 +448,7 @@ impl<'a> Inferencer<'a> { .rev() .collect(); self.unifier - .unify_call(&call, *ty, &sign, &required) + .unify_call(&call, ty, &sign, &required) .map_err(|old| format!("{} at {}", old, location))?; return Ok(sign.ret); } @@ -437,7 +466,7 @@ impl<'a> Inferencer<'a> { }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); - let fields = once((method, call)).collect(); + let fields = once((method, (call, false))).collect(); let record = self.unifier.add_record(fields); self.constrain(obj, record, &location)?; Ok(ret) @@ -538,7 +567,11 @@ impl<'a> Inferencer<'a> { let target = new_context.fold_expr(*generator.target)?; let iter = new_context.fold_expr(*generator.iter)?; if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) { - new_context.unify(target.custom.unwrap(), new_context.primitives.int32, &target.location)?; + new_context.unify( + target.custom.unwrap(), + new_context.primitives.int32, + &target.location, + )?; } else { let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); new_context.unify(iter.custom.unwrap(), list, &iter.location)?; @@ -755,12 +788,27 @@ impl<'a> Inferencer<'a> { &mut self, value: &ast::Expr>, attr: StrRef, + ctx: &ExprContext, ) -> InferenceResult { - let (attr_ty, _) = self.unifier.get_fresh_var(); - let fields = once((attr, attr_ty)).collect(); - let record = self.unifier.add_record(fields); - self.constrain(value.custom.unwrap(), record, &value.location)?; - Ok(attr_ty) + let ty = value.custom.unwrap(); + if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { + // just a fast path + let fields = fields.borrow(); + match (fields.get(&attr), ctx == &ExprContext::Store) { + (Some((ty, true)), _) => Ok(*ty), + (Some((ty, false)), false) => Ok(*ty), + (Some((_, false)), true) => { + report_error(&format!("Field {} should be immutable", attr), value.location) + } + (None, _) => report_error(&format!("No such field {}", attr), value.location), + } + } else { + let (attr_ty, _) = self.unifier.get_fresh_var(); + let fields = once((attr, (attr_ty, ctx == &ExprContext::Store))).collect(); + let record = self.unifier.add_record(fields); + self.constrain(value.custom.unwrap(), record, &value.location)?; + Ok(attr_ty) + } } fn infer_bool_ops(&mut self, values: &[ast::Expr>]) -> InferenceResult { diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 182116d4..d3aa09fa 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -8,8 +8,8 @@ use crate::{ use indoc::indoc; use inkwell::values::BasicValueEnum; use itertools::zip; -use parking_lot::RwLock; use nac3parser::parser::parse_program; +use parking_lot::RwLock; use test_case::test_case; struct Resolver { @@ -75,7 +75,7 @@ impl TestEnvironment { } .into(), )); - fields.borrow_mut().insert("__add__".into(), add_ty); + fields.borrow_mut().insert("__add__".into(), (add_ty, false)); } let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), @@ -170,7 +170,7 @@ impl TestEnvironment { } .into(), )); - fields.borrow_mut().insert("__add__".into(), add_ty); + fields.borrow_mut().insert("__add__".into(), (add_ty, false)); } let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), @@ -203,7 +203,9 @@ impl TestEnvironment { params: HashMap::new().into(), }); identifier_mapping.insert("None".into(), none); - for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate() { + for (i, name) in + ["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate() + { top_level_defs.push( RwLock::new(TopLevelDef::Class { name: (*name).into(), @@ -225,7 +227,7 @@ impl TestEnvironment { let foo_ty = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(7), - fields: [("a".into(), v0)].iter().cloned().collect::>().into(), + fields: [("a".into(), (v0, true))].iter().cloned().collect::>().into(), params: [(id, v0)].iter().cloned().collect::>().into(), }); top_level_defs.push( @@ -233,7 +235,7 @@ impl TestEnvironment { name: "Foo".into(), object_id: DefinitionId(7), type_vars: vec![v0], - fields: [("a".into(), v0)].into(), + fields: [("a".into(), v0, true)].into(), methods: Default::default(), ancestors: Default::default(), resolver: None, @@ -259,7 +261,7 @@ impl TestEnvironment { )); let bar = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(8), - fields: [("a".into(), int32), ("b".into(), fun)] + fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))] .iter() .cloned() .collect::>() @@ -271,7 +273,7 @@ impl TestEnvironment { name: "Bar".into(), object_id: DefinitionId(8), type_vars: Default::default(), - fields: [("a".into(), int32), ("b".into(), fun)].into(), + fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(), methods: Default::default(), ancestors: Default::default(), resolver: None, @@ -288,7 +290,7 @@ impl TestEnvironment { let bar2 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(9), - fields: [("a".into(), bool), ("b".into(), fun)] + fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))] .iter() .cloned() .collect::>() @@ -300,7 +302,7 @@ impl TestEnvironment { name: "Bar2".into(), object_id: DefinitionId(9), type_vars: Default::default(), - fields: [("a".into(), bool), ("b".into(), fun)].into(), + fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(), methods: Default::default(), ancestors: Default::default(), resolver: None, diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 5692723a..eb0cf4bc 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -49,7 +49,7 @@ pub struct FunSignature { pub enum TypeVarMeta { Generic, Sequence(RefCell>), - Record(RefCell>), + Record(RefCell>), } #[derive(Clone)] @@ -71,7 +71,7 @@ pub enum TypeEnum { }, TObj { obj_id: DefinitionId, - fields: RefCell>, + fields: RefCell>, params: RefCell, }, TVirtual { @@ -155,7 +155,7 @@ impl Unifier { self.unification_table.new_key(Rc::new(a)) } - pub fn add_record(&mut self, fields: Mapping) -> Type { + pub fn add_record(&mut self, fields: Mapping) -> Type { let id = self.var_id + 1; self.var_id += 1; self.add_ty(TypeEnum::TVar { @@ -394,11 +394,12 @@ impl Unifier { } (Record(fields1), Record(fields2)) => { let mut fields2 = fields2.borrow_mut(); - for (key, value) in fields1.borrow().iter() { - if let Some(ty) = fields2.get(key) { - self.unify_impl(*ty, *value, false)?; + for (key, (ty, is_mutable)) in fields1.borrow().iter() { + if let Some((ty2, is_mutable2)) = fields2.get_mut(key) { + self.unify_impl(*ty2, *ty, false)?; + *is_mutable2 |= *is_mutable; } else { - fields2.insert(*key, *value); + fields2.insert(*key, (*ty, *is_mutable)); } } } @@ -495,13 +496,19 @@ impl Unifier { self.set_a_to_b(a, b); } (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { - for (k, v) in map.borrow().iter() { - let ty = fields + for (k, (ty, is_mutable)) in map.borrow().iter() { + let (ty2, is_mutable2) = fields .borrow() .get(k) .copied() .ok_or_else(|| format!("No such attribute {}", k))?; - self.unify_impl(ty, *v, false)?; + // typevar represents the usage of the variable + // it is OK to have immutable usage for mutable fields + // but cannot have mutable usage for immutable fields + if *is_mutable && !is_mutable2 { + return Err(format!("Field {} should be immutable", k)); + } + self.unify_impl(*ty, ty2, false)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); self.unify_impl(x, b, false)?; @@ -510,16 +517,19 @@ impl Unifier { (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { let ty = self.get_ty(*ty); if let TObj { fields, .. } = ty.as_ref() { - for (k, v) in map.borrow().iter() { - let ty = fields + for (k, (ty, is_mutable)) in map.borrow().iter() { + let (ty2, is_mutable2) = fields .borrow() .get(k) .copied() .ok_or_else(|| format!("No such attribute {}", k))?; - if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { + if !matches!(self.get_ty(ty2).as_ref(), TFunc { .. }) { return Err(format!("Cannot access field {} for virtual type", k)); } - self.unify_impl(*v, ty, false)?; + if *is_mutable && !is_mutable2 { + return Err(format!("Field {} should be immutable", k)); + } + self.unify_impl(*ty, ty2, false)?; } } else { // require annotation... @@ -643,7 +653,9 @@ impl Unifier { let fields = fields .borrow() .iter() - .map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))) + .map(|(k, (v, _))| { + format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)) + }) .join(", "); format!("record[{}]", fields) } @@ -805,7 +817,7 @@ impl Unifier { let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone()); let fields = self - .subst_map(&fields.borrow(), mapping, cache) + .subst_map2(&fields.borrow(), mapping, cache) .unwrap_or_else(|| fields.borrow().clone()); let new_ty = self.add_ty(TypeEnum::TObj { obj_id, @@ -873,6 +885,27 @@ impl Unifier { map2 } + fn subst_map2( + &mut self, + map: &Mapping, + mapping: &VarMap, + cache: &mut HashMap>, + ) -> Option> + where + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, + { + let mut map2 = None; + for (k, (v, mutability)) in map.iter() { + if let Some(v1) = self.subst_impl(*v, mapping, cache) { + if map2.is_none() { + map2 = Some(map.clone()); + } + *map2.as_mut().unwrap().get_mut(k).unwrap() = (v1, *mutability); + } + } + map2 + } + fn get_intersection(&mut self, a: Type, b: Type) -> Result, ()> { use TypeEnum::*; let x = self.get_ty(a); diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 064f2f9f..2e91b659 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -39,7 +39,7 @@ impl Unifier { ( TypeEnum::TVar { meta: Record(fields1), .. }, TypeEnum::TVar { meta: Record(fields2), .. }, - ) => self.map_eq(&fields1.borrow(), &fields2.borrow()), + ) => self.map_eq2(&fields1.borrow(), &fields2.borrow()), ( TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. }, @@ -63,6 +63,25 @@ impl Unifier { } true } + + fn map_eq2( + &mut self, + map1: &Mapping, + map2: &Mapping, + ) -> bool + where + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, + { + if map1.len() != map2.len() { + return false; + } + for (k, (ty1, m1)) in map1.iter() { + if !map2.get(k).map(|(ty2, m2)| m1 == m2 && self.eq(*ty1, *ty2)).unwrap_or(false) { + return false; + } + } + true + } } struct TestEnvironment { @@ -104,7 +123,11 @@ impl TestEnvironment { "Foo".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: [("a".into(), v0)].iter().cloned().collect::>().into(), + fields: [("a".into(), (v0, true))] + .iter() + .cloned() + .collect::>() + .into(), params: [(id, v0)].iter().cloned().collect::>().into(), }), ); @@ -151,7 +174,7 @@ impl TestEnvironment { let eq = s.find('=').unwrap(); let key = s[1..eq].into(); let result = self.internal_parse(&s[eq + 1..], mapping); - fields.insert(key, result.0); + fields.insert(key, (result.0, true)); s = result.1; } (self.unifier.add_record(fields), &s[1..]) @@ -326,7 +349,7 @@ fn test_recursive_subst() { let foo_ty = env.unifier.get_ty(foo_id); let mapping: HashMap<_, _>; if let TypeEnum::TObj { fields, params, .. } = &*foo_ty { - fields.borrow_mut().insert("rec".into(), foo_id); + fields.borrow_mut().insert("rec".into(), (foo_id, true)); mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect(); } else { unreachable!() @@ -335,8 +358,8 @@ fn test_recursive_subst() { let instantiated_ty = env.unifier.get_ty(instantiated); if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { let fields = fields.borrow(); - assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int)); - assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated)); + assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); + assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); } else { unreachable!() } @@ -351,7 +374,7 @@ fn test_virtual() { )); let bar = env.unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), - fields: [("f".into(), fun), ("a".into(), int)] + fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] .iter() .cloned() .collect::>() @@ -363,15 +386,15 @@ fn test_virtual() { let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); - let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect()); + let c = env.unifier.add_record([("f".into(), (v1, false))].iter().cloned().collect()); env.unifier.unify(a, b).unwrap(); env.unifier.unify(b, c).unwrap(); assert!(env.unifier.eq(v1, fun)); - let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect()); + let d = env.unifier.add_record([("a".into(), (v1, true))].iter().cloned().collect()); assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); - let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect()); + let d = env.unifier.add_record([("b".into(), (v1, true))].iter().cloned().collect()); assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); }