1
0
forked from M-Labs/nac3

nac3core: type check invariants

This rejects code that tries to assign to KernelInvariant fields and
methods.
This commit is contained in:
pca006132 2021-11-06 22:48:08 +08:00
parent 7385b91113
commit b1e83a1fd4
13 changed files with 282 additions and 179 deletions

View File

@ -135,7 +135,7 @@ impl Resolver {
.collect();
let mut fields_ty = HashMap::new();
for method in methods.iter() {
fields_ty.insert(method.0, method.1);
fields_ty.insert(method.0, (method.1, false));
}
for field in fields.iter() {
let name: String = field.0.into();
@ -148,7 +148,7 @@ impl Resolver {
// field type mismatch
return Ok(None);
}
fields_ty.insert(field.0, ty);
fields_ty.insert(field.0, (ty, field.2));
}
for (_, ty) in var_map.iter() {
// must be concrete type
@ -379,7 +379,7 @@ impl Resolver {
if let TopLevelDef::Class { fields, .. } = &*definition {
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, _)| {
.map(|(name, _, _)| {
self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx)
})
.collect();
@ -413,7 +413,7 @@ impl SymbolResolver for Resolver {
id_to_type.get(&str).cloned().or_else(|| {
let py_id = self.name_to_pyid.get(&str);
let result = py_id.and_then(|id| {
self.pyid_to_type.read().get(&id).copied().or_else(|| {
self.pyid_to_type.read().get(id).copied().or_else(|| {
Python::with_gil(|py| -> PyResult<Option<Type>> {
let obj: &PyAny = self.module.extract(py)?;
let members: &PyList = PyModule::import(py, "inspect")?
@ -491,7 +491,7 @@ impl SymbolResolver for Resolver {
let mut id_to_def = self.id_to_def.lock();
id_to_def.get(&id).cloned().or_else(|| {
let py_id = self.name_to_pyid.get(&id);
let result = py_id.and_then(|id| self.pyid_to_def.read().get(&id).copied());
let result = py_id.and_then(|id| self.pyid_to_def.read().get(id).copied());
if let Some(result) = &result {
id_to_def.insert(id, *result);
}

View File

@ -44,7 +44,7 @@ pub enum ConcreteTypeEnum {
},
TObj {
obj_id: DefinitionId,
fields: HashMap<StrRef, ConcreteType>,
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>,
},
TVirtual {
@ -148,7 +148,7 @@ impl ConcreteTypeStore {
.borrow()
.iter()
.map(|(name, ty)| {
(*name, self.from_unifier_type(unifier, primitives, *ty, cache))
(*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))
})
.collect(),
params: params
@ -225,7 +225,7 @@ impl ConcreteTypeStore {
fields: fields
.iter()
.map(|(name, cty)| {
(*name, self.to_unifier_type(unifier, primitives, *cty, cache))
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>()
.into(),

View File

@ -232,7 +232,7 @@ fn get_llvm_type<'ctx>(
let fields = fields.borrow();
let fields = fields_list
.iter()
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0].0))
.collect_vec();
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
} else {

View File

@ -2,16 +2,19 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::{cell::RefCell, sync::Arc};
use crate::{codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}};
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, Unifier},
};
use crate::{
codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef},
};
use crate::{location::Location, typecheck::typedef::TypeEnum};
use itertools::{chain, izip};
use parking_lot::RwLock;
use nac3parser::ast::{Expr, StrRef};
use inkwell::values::BasicValueEnum;
use itertools::{chain, izip};
use nac3parser::ast::{Expr, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
@ -35,7 +38,11 @@ pub trait SymbolResolver {
) -> Option<Type>;
// get the top-level definition of identifiers
fn get_identifier_def(&self, str: StrRef) -> Option<DefinitionId>;
fn get_symbol_value<'ctx, 'a>(&self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>) -> Option<BasicValueEnum<'ctx>>;
fn get_symbol_value<'ctx, 'a>(
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, 'a>,
) -> Option<BasicValueEnum<'ctx>>;
fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
// handle function call etc.
}
@ -62,9 +69,7 @@ pub fn parse_type_annotation<T>(
expr: &Expr<T>,
) -> Result<Type, String> {
use nac3parser::ast::ExprKind::*;
let ids = IDENTIFIER_ID.with(|ids| {
*ids
});
let ids = IDENTIFIER_ID.with(|ids| *ids);
let int32_id = ids[0];
let int64_id = ids[1];
let float_id = ids[2];
@ -99,8 +104,8 @@ pub fn parse_type_annotation<T>(
}
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v)| (*k, *v)),
methods.iter().map(|(k, v, _)| (*k, *v)),
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect(),
);
@ -124,7 +129,7 @@ pub fn parse_type_annotation<T>(
}
}
}
},
}
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
if *id == virtual_id {
@ -209,14 +214,14 @@ pub fn parse_type_annotation<T>(
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
.map(|(attr, ty, is_mutable)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, ty)
(*attr, (ty, *is_mutable))
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, ty)
(*attr, (ty, false))
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,

View File

@ -461,7 +461,7 @@ impl TopLevelComposer {
"str".into(),
"self".into(),
"Kernel".into(),
"KernelImmutable".into(),
"KernelInvariant".into(),
]);
let defined_names: HashSet<String> = Default::default();
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
@ -1391,22 +1391,24 @@ impl TopLevelComposer {
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_fresh_var().0;
class_fields_def.push((*attr, dummy_field_type));
// handle Kernel[T], KernelImmutable[T]
let annotation = {
match &annotation.as_ref().node {
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. }
if id == &"Kernel".into() || id == &"KernelImmutable".into())
} =>
{
slice
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = {
let mut result = None;
if let ast::ExprKind::Subscript { value, slice, .. } = &annotation.as_ref().node {
if let ast::ExprKind::Name { id, .. } = &value.node {
result = if id == &"Kernel".into() {
Some((slice, true))
} else if id == &"KernelInvariant".into() {
Some((slice, false))
} else {
None
}
_ => annotation,
}
}
result.unwrap_or((annotation, true))
};
class_fields_def.push((*attr, dummy_field_type, mutable));
let annotation = parse_ast_to_type_annotation_kinds(
class_resolver,
@ -1532,26 +1534,13 @@ impl TopLevelComposer {
class_methods_def.extend(new_child_methods);
// handle class fields
let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new();
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
// let mut is_override: HashSet<_> = HashSet::new();
for (anc_field_name, anc_field_ty) in fields {
let to_be_added = (*anc_field_name, *anc_field_ty);
for (anc_field_name, anc_field_ty, mutable) in fields {
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
// find if there is a fields with the same name in the child class
for (class_field_name, ..) in class_fields_def.iter() {
if class_field_name == anc_field_name {
// let ok = Self::check_overload_field_type(
// *class_field_ty,
// *anc_field_ty,
// unifier,
// type_var_to_concrete_def,
// );
// if !ok {
// return Err("fields has same name as ancestors' field, but incompatible type".into());
// }
// // mark it as added
// is_override.insert(class_field_name.to_string());
// to_be_added = (class_field_name.to_string(), *class_field_ty);
// break;
return Err(format!(
"field `{}` has already declared in the ancestor classes",
class_field_name
@ -1560,9 +1549,9 @@ impl TopLevelComposer {
}
new_child_fields.push(to_be_added);
}
for (class_field_name, class_field_ty) in class_fields_def.iter() {
for (class_field_name, class_field_ty, mutable) in class_fields_def.iter() {
if !is_override.contains(class_field_name) {
new_child_fields.push((*class_field_name, *class_field_ty));
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
}
}
class_fields_def.drain(..);
@ -1636,7 +1625,7 @@ impl TopLevelComposer {
unreachable!("must be init function here")
}
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
if fields.iter().any(|x| !all_inited.contains(&x.0)) {
return Err(format!(
"fields of class {} not fully initialized",
class_name

View File

@ -12,7 +12,7 @@ impl TopLevelDef {
} => {
let fields_str = fields
.iter()
.map(|(n, ty)| {
.map(|(n, ty, _)| {
(n.to_string(), unifier.default_stringify(*ty))
})
.collect_vec();

View File

@ -85,7 +85,8 @@ pub enum TopLevelDef {
/// type variables bounded to the class.
type_vars: Vec<Type>,
// class fields
fields: Vec<(StrRef, Type)>,
// name, type, is mutable
fields: Vec<(StrRef, Type, bool)>,
// class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
// ancestor classes, including itself.

View File

@ -307,25 +307,15 @@ pub fn get_type_from_type_annotation_kinds(
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, subst_ty)
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, Type>>();
tobj_fields.extend(fields.iter().map(|(name, ty)| {
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, subst_ty)
(*name, (subst_ty, *mutability))
}));
// println!("tobj_fields: {:?}", tobj_fields);
// println!(
// "{:?}: {}\n",
// tobj_fields.get("__init__").unwrap(),
// unifier.stringify(
// *tobj_fields.get("__init__").unwrap(),
// &mut |id| format!("class{}", id),
// &mut |id| format!("tvar{}", id)
// )
// );
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: RefCell::new(tobj_fields),

View File

@ -86,6 +86,7 @@ pub fn impl_binop(
};
for op in ops {
fields.borrow_mut().insert(binop_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: ret_ty,
@ -97,10 +98,13 @@ pub fn impl_binop(
}],
}
.into(),
))
)),
false,
)
});
fields.borrow_mut().insert(binop_assign_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.none,
@ -112,7 +116,9 @@ pub fn impl_binop(
}],
}
.into(),
))
)),
false,
)
});
}
} else {
@ -131,9 +137,12 @@ pub fn impl_unaryop(
for op in ops {
fields.borrow_mut().insert(
unaryop_name(op).into(),
(
unifier.add_ty(TypeEnum::TFunc(
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
)),
false,
),
);
}
} else {
@ -152,6 +161,7 @@ pub fn impl_cmpop(
for op in ops {
fields.borrow_mut().insert(
comparison_name(op).unwrap().into(),
(
unifier.add_ty(TypeEnum::TFunc(
FunSignature {
ret: store.bool,
@ -164,6 +174,8 @@ pub fn impl_cmpop(
}
.into(),
)),
false,
),
);
}
} else {

View File

@ -10,7 +10,7 @@ use itertools::izip;
use nac3parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprKind, Located, Location, StrRef,
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
};
#[cfg(test)]
@ -77,11 +77,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
Ok(None)
}
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
fn fold_stmt(
&mut self,
mut node: ast::Stmt<()>,
) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
let stmt = match node.node {
// we don't want fold over type annotation
ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment } => {
ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => {
self.infer_pattern(&target)?;
// fix parser problem...
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
*ctx = ExprContext::Store;
}
let target = Box::new(self.fold_expr(*target)?);
let value = if let Some(v) = value {
let ty = Box::new(self.fold_expr(*v)?);
@ -105,14 +113,25 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
Located {
location: node.location,
custom: None,
node: ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment },
node: ast::StmtKind::AnnAssign {
target,
annotation,
value,
simple,
config_comment,
},
}
}
ast::StmtKind::For { ref target, .. } => {
self.infer_pattern(target)?;
fold::fold_stmt(self, node)?
}
ast::StmtKind::Assign { ref targets, ref config_comment, .. } => {
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
for target in targets.iter_mut() {
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
*ctx = ExprContext::Store;
}
}
if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) {
if let ast::StmtKind::Assign { targets, value, .. } = node.node {
let value = self.fold_expr(*value)?;
@ -158,7 +177,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
targets,
value: Box::new(value),
type_comment: None,
config_comment: config_comment.clone()
config_comment: config_comment.clone(),
},
custom: None,
});
@ -199,7 +218,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}
}
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Pass { .. } => {}
ast::StmtKind::Break { .. }
| ast::StmtKind::Continue { .. }
| ast::StmtKind::Pass { .. } => {}
ast::StmtKind::With { items, .. } => {
for item in items.iter() {
let ty = item.context_expr.custom.unwrap();
@ -209,17 +230,21 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
let fields = fields.borrow();
fast_path = true;
if let Some(enter) = fields.get(&"__enter__".into()).cloned() {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) {
let signature = signature.borrow();
if !signature.args.is_empty() {
return report_error(
"__enter__ method should take no argument other than self",
stmt.location
)
stmt.location,
);
}
if let Some(var) = &item.optional_vars {
if signature.vars.is_empty() {
self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?;
self.unify(
signature.ret,
var.custom.unwrap(),
&stmt.location,
)?;
} else {
fast_path = false;
}
@ -230,17 +255,17 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} else {
return report_error(
"__enter__ method is required for context manager",
stmt.location
stmt.location,
);
}
if let Some(exit) = fields.get(&"__exit__".into()).cloned() {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) {
let signature = signature.borrow();
if !signature.args.is_empty() {
return report_error(
"__exit__ method should take no argument other than self",
stmt.location
)
stmt.location,
);
}
} else {
fast_path = false;
@ -248,26 +273,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} else {
return report_error(
"__exit__ method is required for context manager",
stmt.location
stmt.location,
);
}
}
if !fast_path {
let enter = TypeEnum::TFunc(RefCell::new(FunSignature {
args: vec![],
ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()),
vars: Default::default()
ret: item.optional_vars.as_ref().map_or_else(
|| self.unifier.get_fresh_var().0,
|var| var.custom.unwrap(),
),
vars: Default::default(),
}));
let enter = self.unifier.add_ty(enter);
let exit = TypeEnum::TFunc(RefCell::new(FunSignature {
args: vec![],
ret: self.unifier.get_fresh_var().0,
vars: Default::default()
vars: Default::default(),
}));
let exit = self.unifier.add_ty(exit);
let mut fields = HashMap::new();
fields.insert("__enter__".into(), enter);
fields.insert("__exit__".into(), exit);
fields.insert("__enter__".into(), (enter, false));
fields.insert("__exit__".into(), (exit, false));
let record = self.unifier.add_record(fields);
self.unify(ty, record, &stmt.location)?;
}
@ -330,8 +358,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
Some(self.infer_attribute(value, *attr)?)
ast::ExprKind::Attribute { value, attr, ctx } => {
Some(self.infer_attribute(value, *attr, ctx)?)
}
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ast::ExprKind::BinOp { left, op, right } => {
@ -399,7 +427,8 @@ impl<'a> Inferencer<'a> {
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
if class_params.borrow().is_empty() {
if let Some(ty) = fields.borrow().get(&method) {
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) {
let ty = ty.0;
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
let sign = sign.borrow();
if sign.vars.is_empty() {
let call = Call {
@ -419,7 +448,7 @@ impl<'a> Inferencer<'a> {
.rev()
.collect();
self.unifier
.unify_call(&call, *ty, &sign, &required)
.unify_call(&call, ty, &sign, &required)
.map_err(|old| format!("{} at {}", old, location))?;
return Ok(sign.ret);
}
@ -437,7 +466,7 @@ impl<'a> Inferencer<'a> {
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect();
let fields = once((method, (call, false))).collect();
let record = self.unifier.add_record(fields);
self.constrain(obj, record, &location)?;
Ok(ret)
@ -538,7 +567,11 @@ impl<'a> Inferencer<'a> {
let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?;
if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) {
new_context.unify(target.custom.unwrap(), new_context.primitives.int32, &target.location)?;
new_context.unify(
target.custom.unwrap(),
new_context.primitives.int32,
&target.location,
)?;
} else {
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
@ -755,13 +788,28 @@ impl<'a> Inferencer<'a> {
&mut self,
value: &ast::Expr<Option<Type>>,
attr: StrRef,
ctx: &ExprContext,
) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
let fields = fields.borrow();
match (fields.get(&attr), ctx == &ExprContext::Store) {
(Some((ty, true)), _) => Ok(*ty),
(Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => {
report_error(&format!("Field {} should be immutable", attr), value.location)
}
(None, _) => report_error(&format!("No such field {}", attr), value.location),
}
} else {
let (attr_ty, _) = self.unifier.get_fresh_var();
let fields = once((attr, attr_ty)).collect();
let fields = once((attr, (attr_ty, ctx == &ExprContext::Store))).collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty)
}
}
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let b = self.primitives.bool;

View File

@ -8,8 +8,8 @@ use crate::{
use indoc::indoc;
use inkwell::values::BasicValueEnum;
use itertools::zip;
use parking_lot::RwLock;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use test_case::test_case;
struct Resolver {
@ -75,7 +75,7 @@ impl TestEnvironment {
}
.into(),
));
fields.borrow_mut().insert("__add__".into(), add_ty);
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
}
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
@ -170,7 +170,7 @@ impl TestEnvironment {
}
.into(),
));
fields.borrow_mut().insert("__add__".into(), add_ty);
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
}
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
@ -203,7 +203,9 @@ impl TestEnvironment {
params: HashMap::new().into(),
});
identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate() {
for (i, name) in
["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate()
{
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: (*name).into(),
@ -225,7 +227,7 @@ impl TestEnvironment {
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
});
top_level_defs.push(
@ -233,7 +235,7 @@ impl TestEnvironment {
name: "Foo".into(),
object_id: DefinitionId(7),
type_vars: vec![v0],
fields: [("a".into(), v0)].into(),
fields: [("a".into(), v0, true)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
@ -259,7 +261,7 @@ impl TestEnvironment {
));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8),
fields: [("a".into(), int32), ("b".into(), fun)]
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
@ -271,7 +273,7 @@ impl TestEnvironment {
name: "Bar".into(),
object_id: DefinitionId(8),
type_vars: Default::default(),
fields: [("a".into(), int32), ("b".into(), fun)].into(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
@ -288,7 +290,7 @@ impl TestEnvironment {
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9),
fields: [("a".into(), bool), ("b".into(), fun)]
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
@ -300,7 +302,7 @@ impl TestEnvironment {
name: "Bar2".into(),
object_id: DefinitionId(9),
type_vars: Default::default(),
fields: [("a".into(), bool), ("b".into(), fun)].into(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,

View File

@ -49,7 +49,7 @@ pub struct FunSignature {
pub enum TypeVarMeta {
Generic,
Sequence(RefCell<Mapping<i32>>),
Record(RefCell<Mapping<StrRef>>),
Record(RefCell<Mapping<StrRef, (Type, bool)>>),
}
#[derive(Clone)]
@ -71,7 +71,7 @@ pub enum TypeEnum {
},
TObj {
obj_id: DefinitionId,
fields: RefCell<Mapping<StrRef>>,
fields: RefCell<Mapping<StrRef, (Type, bool)>>,
params: RefCell<VarMap>,
},
TVirtual {
@ -155,7 +155,7 @@ impl Unifier {
self.unification_table.new_key(Rc::new(a))
}
pub fn add_record(&mut self, fields: Mapping<StrRef>) -> Type {
pub fn add_record(&mut self, fields: Mapping<StrRef, (Type, bool)>) -> Type {
let id = self.var_id + 1;
self.var_id += 1;
self.add_ty(TypeEnum::TVar {
@ -394,11 +394,12 @@ impl Unifier {
}
(Record(fields1), Record(fields2)) => {
let mut fields2 = fields2.borrow_mut();
for (key, value) in fields1.borrow().iter() {
if let Some(ty) = fields2.get(key) {
self.unify_impl(*ty, *value, false)?;
for (key, (ty, is_mutable)) in fields1.borrow().iter() {
if let Some((ty2, is_mutable2)) = fields2.get_mut(key) {
self.unify_impl(*ty2, *ty, false)?;
*is_mutable2 |= *is_mutable;
} else {
fields2.insert(*key, *value);
fields2.insert(*key, (*ty, *is_mutable));
}
}
}
@ -495,13 +496,19 @@ impl Unifier {
self.set_a_to_b(a, b);
}
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
for (k, v) in map.borrow().iter() {
let ty = fields
for (k, (ty, is_mutable)) in map.borrow().iter() {
let (ty2, is_mutable2) = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
self.unify_impl(ty, *v, false)?;
// typevar represents the usage of the variable
// it is OK to have immutable usage for mutable fields
// but cannot have mutable usage for immutable fields
if *is_mutable && !is_mutable2 {
return Err(format!("Field {} should be immutable", k));
}
self.unify_impl(*ty, ty2, false)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify_impl(x, b, false)?;
@ -510,16 +517,19 @@ impl Unifier {
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() {
let ty = fields
for (k, (ty, is_mutable)) in map.borrow().iter() {
let (ty2, is_mutable2) = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
if !matches!(self.get_ty(ty2).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k));
}
self.unify_impl(*v, ty, false)?;
if *is_mutable && !is_mutable2 {
return Err(format!("Field {} should be immutable", k));
}
self.unify_impl(*ty, ty2, false)?;
}
} else {
// require annotation...
@ -643,7 +653,9 @@ impl Unifier {
let fields = fields
.borrow()
.iter()
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
.map(|(k, (v, _))| {
format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))
})
.join(", ");
format!("record[{}]", fields)
}
@ -805,7 +817,7 @@ impl Unifier {
let params =
self.subst_map(&params, mapping, cache).unwrap_or_else(|| params.clone());
let fields = self
.subst_map(&fields.borrow(), mapping, cache)
.subst_map2(&fields.borrow(), mapping, cache)
.unwrap_or_else(|| fields.borrow().clone());
let new_ty = self.add_ty(TypeEnum::TObj {
obj_id,
@ -873,6 +885,27 @@ impl Unifier {
map2
}
fn subst_map2<K>(
&mut self,
map: &Mapping<K, (Type, bool)>,
mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>,
) -> Option<Mapping<K, (Type, bool)>>
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
let mut map2 = None;
for (k, (v, mutability)) in map.iter() {
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
if map2.is_none() {
map2 = Some(map.clone());
}
*map2.as_mut().unwrap().get_mut(k).unwrap() = (v1, *mutability);
}
}
map2
}
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
use TypeEnum::*;
let x = self.get_ty(a);

View File

@ -39,7 +39,7 @@ impl Unifier {
(
TypeEnum::TVar { meta: Record(fields1), .. },
TypeEnum::TVar { meta: Record(fields2), .. },
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
) => self.map_eq2(&fields1.borrow(), &fields2.borrow()),
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
@ -63,6 +63,25 @@ impl Unifier {
}
true
}
fn map_eq2<K>(
&mut self,
map1: &Mapping<K, (Type, bool)>,
map2: &Mapping<K, (Type, bool)>,
) -> bool
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
if map1.len() != map2.len() {
return false;
}
for (k, (ty1, m1)) in map1.iter() {
if !map2.get(k).map(|(ty2, m2)| m1 == m2 && self.eq(*ty1, *ty2)).unwrap_or(false) {
return false;
}
}
true
}
}
struct TestEnvironment {
@ -104,7 +123,11 @@ impl TestEnvironment {
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
fields: [("a".into(), (v0, true))]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
}),
);
@ -151,7 +174,7 @@ impl TestEnvironment {
let eq = s.find('=').unwrap();
let key = s[1..eq].into();
let result = self.internal_parse(&s[eq + 1..], mapping);
fields.insert(key, result.0);
fields.insert(key, (result.0, true));
s = result.1;
}
(self.unifier.add_record(fields), &s[1..])
@ -326,7 +349,7 @@ fn test_recursive_subst() {
let foo_ty = env.unifier.get_ty(foo_id);
let mapping: HashMap<_, _>;
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
fields.borrow_mut().insert("rec".into(), foo_id);
fields.borrow_mut().insert("rec".into(), (foo_id, true));
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
} else {
unreachable!()
@ -335,8 +358,8 @@ fn test_recursive_subst() {
let instantiated_ty = env.unifier.get_ty(instantiated);
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
let fields = fields.borrow();
assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int));
assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated));
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
} else {
unreachable!()
}
@ -351,7 +374,7 @@ fn test_virtual() {
));
let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("f".into(), fun), ("a".into(), int)]
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
.iter()
.cloned()
.collect::<HashMap<StrRef, _>>()
@ -363,15 +386,15 @@ fn test_virtual() {
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect());
let c = env.unifier.add_record([("f".into(), (v1, false))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect());
let d = env.unifier.add_record([("a".into(), (v1, true))].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect());
let d = env.unifier.add_record([("b".into(), (v1, true))].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
}