forked from M-Labs/nac3
nac3core: type check invariants
This rejects code that tries to assign to KernelInvariant fields and methods.
This commit is contained in:
parent
7385b91113
commit
b1e83a1fd4
|
@ -135,7 +135,7 @@ impl Resolver {
|
||||||
.collect();
|
.collect();
|
||||||
let mut fields_ty = HashMap::new();
|
let mut fields_ty = HashMap::new();
|
||||||
for method in methods.iter() {
|
for method in methods.iter() {
|
||||||
fields_ty.insert(method.0, method.1);
|
fields_ty.insert(method.0, (method.1, false));
|
||||||
}
|
}
|
||||||
for field in fields.iter() {
|
for field in fields.iter() {
|
||||||
let name: String = field.0.into();
|
let name: String = field.0.into();
|
||||||
|
@ -148,7 +148,7 @@ impl Resolver {
|
||||||
// field type mismatch
|
// field type mismatch
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
fields_ty.insert(field.0, ty);
|
fields_ty.insert(field.0, (ty, field.2));
|
||||||
}
|
}
|
||||||
for (_, ty) in var_map.iter() {
|
for (_, ty) in var_map.iter() {
|
||||||
// must be concrete type
|
// must be concrete type
|
||||||
|
@ -379,7 +379,7 @@ impl Resolver {
|
||||||
if let TopLevelDef::Class { fields, .. } = &*definition {
|
if let TopLevelDef::Class { fields, .. } = &*definition {
|
||||||
let values: Result<Option<Vec<_>>, _> = fields
|
let values: Result<Option<Vec<_>>, _> = fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, _)| {
|
.map(|(name, _, _)| {
|
||||||
self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx)
|
self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -413,7 +413,7 @@ impl SymbolResolver for Resolver {
|
||||||
id_to_type.get(&str).cloned().or_else(|| {
|
id_to_type.get(&str).cloned().or_else(|| {
|
||||||
let py_id = self.name_to_pyid.get(&str);
|
let py_id = self.name_to_pyid.get(&str);
|
||||||
let result = py_id.and_then(|id| {
|
let result = py_id.and_then(|id| {
|
||||||
self.pyid_to_type.read().get(&id).copied().or_else(|| {
|
self.pyid_to_type.read().get(id).copied().or_else(|| {
|
||||||
Python::with_gil(|py| -> PyResult<Option<Type>> {
|
Python::with_gil(|py| -> PyResult<Option<Type>> {
|
||||||
let obj: &PyAny = self.module.extract(py)?;
|
let obj: &PyAny = self.module.extract(py)?;
|
||||||
let members: &PyList = PyModule::import(py, "inspect")?
|
let members: &PyList = PyModule::import(py, "inspect")?
|
||||||
|
@ -491,7 +491,7 @@ impl SymbolResolver for Resolver {
|
||||||
let mut id_to_def = self.id_to_def.lock();
|
let mut id_to_def = self.id_to_def.lock();
|
||||||
id_to_def.get(&id).cloned().or_else(|| {
|
id_to_def.get(&id).cloned().or_else(|| {
|
||||||
let py_id = self.name_to_pyid.get(&id);
|
let py_id = self.name_to_pyid.get(&id);
|
||||||
let result = py_id.and_then(|id| self.pyid_to_def.read().get(&id).copied());
|
let result = py_id.and_then(|id| self.pyid_to_def.read().get(id).copied());
|
||||||
if let Some(result) = &result {
|
if let Some(result) = &result {
|
||||||
id_to_def.insert(id, *result);
|
id_to_def.insert(id, *result);
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ pub enum ConcreteTypeEnum {
|
||||||
},
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: HashMap<StrRef, ConcreteType>,
|
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||||
params: HashMap<u32, ConcreteType>,
|
params: HashMap<u32, ConcreteType>,
|
||||||
},
|
},
|
||||||
TVirtual {
|
TVirtual {
|
||||||
|
@ -148,7 +148,7 @@ impl ConcreteTypeStore {
|
||||||
.borrow()
|
.borrow()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, ty)| {
|
.map(|(name, ty)| {
|
||||||
(*name, self.from_unifier_type(unifier, primitives, *ty, cache))
|
(*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
params: params
|
params: params
|
||||||
|
@ -225,7 +225,7 @@ impl ConcreteTypeStore {
|
||||||
fields: fields
|
fields: fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, cty)| {
|
.map(|(name, cty)| {
|
||||||
(*name, self.to_unifier_type(unifier, primitives, *cty, cache))
|
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
|
||||||
})
|
})
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
.into(),
|
.into(),
|
||||||
|
|
|
@ -232,7 +232,7 @@ fn get_llvm_type<'ctx>(
|
||||||
let fields = fields.borrow();
|
let fields = fields.borrow();
|
||||||
let fields = fields_list
|
let fields = fields_list
|
||||||
.iter()
|
.iter()
|
||||||
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
|
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0].0))
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -2,16 +2,19 @@ use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use crate::{codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}};
|
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::PrimitiveStore,
|
type_inferencer::PrimitiveStore,
|
||||||
typedef::{Type, Unifier},
|
typedef::{Type, Unifier},
|
||||||
};
|
};
|
||||||
|
use crate::{
|
||||||
|
codegen::CodeGenContext,
|
||||||
|
toplevel::{DefinitionId, TopLevelDef},
|
||||||
|
};
|
||||||
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
||||||
use itertools::{chain, izip};
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use nac3parser::ast::{Expr, StrRef};
|
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::values::BasicValueEnum;
|
||||||
|
use itertools::{chain, izip};
|
||||||
|
use nac3parser::ast::{Expr, StrRef};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Debug)]
|
#[derive(Clone, PartialEq, Debug)]
|
||||||
pub enum SymbolValue {
|
pub enum SymbolValue {
|
||||||
|
@ -35,7 +38,11 @@ pub trait SymbolResolver {
|
||||||
) -> Option<Type>;
|
) -> Option<Type>;
|
||||||
// get the top-level definition of identifiers
|
// get the top-level definition of identifiers
|
||||||
fn get_identifier_def(&self, str: StrRef) -> Option<DefinitionId>;
|
fn get_identifier_def(&self, str: StrRef) -> Option<DefinitionId>;
|
||||||
fn get_symbol_value<'ctx, 'a>(&self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>) -> Option<BasicValueEnum<'ctx>>;
|
fn get_symbol_value<'ctx, 'a>(
|
||||||
|
&self,
|
||||||
|
str: StrRef,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>>;
|
||||||
fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
|
fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
|
||||||
// handle function call etc.
|
// handle function call etc.
|
||||||
}
|
}
|
||||||
|
@ -62,9 +69,7 @@ pub fn parse_type_annotation<T>(
|
||||||
expr: &Expr<T>,
|
expr: &Expr<T>,
|
||||||
) -> Result<Type, String> {
|
) -> Result<Type, String> {
|
||||||
use nac3parser::ast::ExprKind::*;
|
use nac3parser::ast::ExprKind::*;
|
||||||
let ids = IDENTIFIER_ID.with(|ids| {
|
let ids = IDENTIFIER_ID.with(|ids| *ids);
|
||||||
*ids
|
|
||||||
});
|
|
||||||
let int32_id = ids[0];
|
let int32_id = ids[0];
|
||||||
let int64_id = ids[1];
|
let int64_id = ids[1];
|
||||||
let float_id = ids[2];
|
let float_id = ids[2];
|
||||||
|
@ -99,8 +104,8 @@ pub fn parse_type_annotation<T>(
|
||||||
}
|
}
|
||||||
let fields = RefCell::new(
|
let fields = RefCell::new(
|
||||||
chain(
|
chain(
|
||||||
fields.iter().map(|(k, v)| (*k, *v)),
|
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
|
||||||
methods.iter().map(|(k, v, _)| (*k, *v)),
|
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
|
||||||
)
|
)
|
||||||
.collect(),
|
.collect(),
|
||||||
);
|
);
|
||||||
|
@ -124,7 +129,7 @@ pub fn parse_type_annotation<T>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Subscript { value, slice, .. } => {
|
Subscript { value, slice, .. } => {
|
||||||
if let Name { id, .. } = &value.node {
|
if let Name { id, .. } = &value.node {
|
||||||
if *id == virtual_id {
|
if *id == virtual_id {
|
||||||
|
@ -209,14 +214,14 @@ pub fn parse_type_annotation<T>(
|
||||||
}
|
}
|
||||||
let mut fields = fields
|
let mut fields = fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(attr, ty)| {
|
.map(|(attr, ty, is_mutable)| {
|
||||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
(*attr, ty)
|
(*attr, (ty, *is_mutable))
|
||||||
})
|
})
|
||||||
.collect::<HashMap<_, _>>();
|
.collect::<HashMap<_, _>>();
|
||||||
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
||||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
(*attr, ty)
|
(*attr, (ty, false))
|
||||||
}));
|
}));
|
||||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id,
|
obj_id,
|
||||||
|
|
|
@ -461,7 +461,7 @@ impl TopLevelComposer {
|
||||||
"str".into(),
|
"str".into(),
|
||||||
"self".into(),
|
"self".into(),
|
||||||
"Kernel".into(),
|
"Kernel".into(),
|
||||||
"KernelImmutable".into(),
|
"KernelInvariant".into(),
|
||||||
]);
|
]);
|
||||||
let defined_names: HashSet<String> = Default::default();
|
let defined_names: HashSet<String> = Default::default();
|
||||||
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
|
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
|
||||||
|
@ -1391,22 +1391,24 @@ impl TopLevelComposer {
|
||||||
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
||||||
if defined_fields.insert(attr.to_string()) {
|
if defined_fields.insert(attr.to_string()) {
|
||||||
let dummy_field_type = unifier.get_fresh_var().0;
|
let dummy_field_type = unifier.get_fresh_var().0;
|
||||||
class_fields_def.push((*attr, dummy_field_type));
|
|
||||||
|
|
||||||
// handle Kernel[T], KernelImmutable[T]
|
// handle Kernel[T], KernelInvariant[T]
|
||||||
let annotation = {
|
let (annotation, mutable) = {
|
||||||
match &annotation.as_ref().node {
|
let mut result = None;
|
||||||
ast::ExprKind::Subscript { value, slice, .. }
|
if let ast::ExprKind::Subscript { value, slice, .. } = &annotation.as_ref().node {
|
||||||
if {
|
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||||
matches!(&value.node, ast::ExprKind::Name { id, .. }
|
result = if id == &"Kernel".into() {
|
||||||
if id == &"Kernel".into() || id == &"KernelImmutable".into())
|
Some((slice, true))
|
||||||
} =>
|
} else if id == &"KernelInvariant".into() {
|
||||||
{
|
Some((slice, false))
|
||||||
slice
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
_ => annotation,
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
result.unwrap_or((annotation, true))
|
||||||
};
|
};
|
||||||
|
class_fields_def.push((*attr, dummy_field_type, mutable));
|
||||||
|
|
||||||
let annotation = parse_ast_to_type_annotation_kinds(
|
let annotation = parse_ast_to_type_annotation_kinds(
|
||||||
class_resolver,
|
class_resolver,
|
||||||
|
@ -1532,26 +1534,13 @@ impl TopLevelComposer {
|
||||||
class_methods_def.extend(new_child_methods);
|
class_methods_def.extend(new_child_methods);
|
||||||
|
|
||||||
// handle class fields
|
// handle class fields
|
||||||
let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new();
|
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
|
||||||
// let mut is_override: HashSet<_> = HashSet::new();
|
// let mut is_override: HashSet<_> = HashSet::new();
|
||||||
for (anc_field_name, anc_field_ty) in fields {
|
for (anc_field_name, anc_field_ty, mutable) in fields {
|
||||||
let to_be_added = (*anc_field_name, *anc_field_ty);
|
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
|
||||||
// find if there is a fields with the same name in the child class
|
// find if there is a fields with the same name in the child class
|
||||||
for (class_field_name, ..) in class_fields_def.iter() {
|
for (class_field_name, ..) in class_fields_def.iter() {
|
||||||
if class_field_name == anc_field_name {
|
if class_field_name == anc_field_name {
|
||||||
// let ok = Self::check_overload_field_type(
|
|
||||||
// *class_field_ty,
|
|
||||||
// *anc_field_ty,
|
|
||||||
// unifier,
|
|
||||||
// type_var_to_concrete_def,
|
|
||||||
// );
|
|
||||||
// if !ok {
|
|
||||||
// return Err("fields has same name as ancestors' field, but incompatible type".into());
|
|
||||||
// }
|
|
||||||
// // mark it as added
|
|
||||||
// is_override.insert(class_field_name.to_string());
|
|
||||||
// to_be_added = (class_field_name.to_string(), *class_field_ty);
|
|
||||||
// break;
|
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"field `{}` has already declared in the ancestor classes",
|
"field `{}` has already declared in the ancestor classes",
|
||||||
class_field_name
|
class_field_name
|
||||||
|
@ -1560,9 +1549,9 @@ impl TopLevelComposer {
|
||||||
}
|
}
|
||||||
new_child_fields.push(to_be_added);
|
new_child_fields.push(to_be_added);
|
||||||
}
|
}
|
||||||
for (class_field_name, class_field_ty) in class_fields_def.iter() {
|
for (class_field_name, class_field_ty, mutable) in class_fields_def.iter() {
|
||||||
if !is_override.contains(class_field_name) {
|
if !is_override.contains(class_field_name) {
|
||||||
new_child_fields.push((*class_field_name, *class_field_ty));
|
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
class_fields_def.drain(..);
|
class_fields_def.drain(..);
|
||||||
|
@ -1636,7 +1625,7 @@ impl TopLevelComposer {
|
||||||
unreachable!("must be init function here")
|
unreachable!("must be init function here")
|
||||||
}
|
}
|
||||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||||
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
|
if fields.iter().any(|x| !all_inited.contains(&x.0)) {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"fields of class {} not fully initialized",
|
"fields of class {} not fully initialized",
|
||||||
class_name
|
class_name
|
||||||
|
|
|
@ -12,7 +12,7 @@ impl TopLevelDef {
|
||||||
} => {
|
} => {
|
||||||
let fields_str = fields
|
let fields_str = fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(n, ty)| {
|
.map(|(n, ty, _)| {
|
||||||
(n.to_string(), unifier.default_stringify(*ty))
|
(n.to_string(), unifier.default_stringify(*ty))
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
|
@ -85,7 +85,8 @@ pub enum TopLevelDef {
|
||||||
/// type variables bounded to the class.
|
/// type variables bounded to the class.
|
||||||
type_vars: Vec<Type>,
|
type_vars: Vec<Type>,
|
||||||
// class fields
|
// class fields
|
||||||
fields: Vec<(StrRef, Type)>,
|
// name, type, is mutable
|
||||||
|
fields: Vec<(StrRef, Type, bool)>,
|
||||||
// class methods, pointing to the corresponding function definition.
|
// class methods, pointing to the corresponding function definition.
|
||||||
methods: Vec<(StrRef, Type, DefinitionId)>,
|
methods: Vec<(StrRef, Type, DefinitionId)>,
|
||||||
// ancestor classes, including itself.
|
// ancestor classes, including itself.
|
||||||
|
|
|
@ -307,25 +307,15 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, ty, _)| {
|
.map(|(name, ty, _)| {
|
||||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
(*name, subst_ty)
|
// methods are immutable
|
||||||
|
(*name, (subst_ty, false))
|
||||||
})
|
})
|
||||||
.collect::<HashMap<_, Type>>();
|
.collect::<HashMap<_, _>>();
|
||||||
tobj_fields.extend(fields.iter().map(|(name, ty)| {
|
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
||||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
(*name, subst_ty)
|
(*name, (subst_ty, *mutability))
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// println!("tobj_fields: {:?}", tobj_fields);
|
|
||||||
// println!(
|
|
||||||
// "{:?}: {}\n",
|
|
||||||
// tobj_fields.get("__init__").unwrap(),
|
|
||||||
// unifier.stringify(
|
|
||||||
// *tobj_fields.get("__init__").unwrap(),
|
|
||||||
// &mut |id| format!("class{}", id),
|
|
||||||
// &mut |id| format!("tvar{}", id)
|
|
||||||
// )
|
|
||||||
// );
|
|
||||||
|
|
||||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
fields: RefCell::new(tobj_fields),
|
fields: RefCell::new(tobj_fields),
|
||||||
|
|
|
@ -86,6 +86,7 @@ pub fn impl_binop(
|
||||||
};
|
};
|
||||||
for op in ops {
|
for op in ops {
|
||||||
fields.borrow_mut().insert(binop_name(op).into(), {
|
fields.borrow_mut().insert(binop_name(op).into(), {
|
||||||
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
FunSignature {
|
FunSignature {
|
||||||
ret: ret_ty,
|
ret: ret_ty,
|
||||||
|
@ -97,10 +98,13 @@ pub fn impl_binop(
|
||||||
}],
|
}],
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
))
|
)),
|
||||||
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
fields.borrow_mut().insert(binop_assign_name(op).into(), {
|
fields.borrow_mut().insert(binop_assign_name(op).into(), {
|
||||||
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
FunSignature {
|
FunSignature {
|
||||||
ret: store.none,
|
ret: store.none,
|
||||||
|
@ -112,7 +116,9 @@ pub fn impl_binop(
|
||||||
}],
|
}],
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
))
|
)),
|
||||||
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -131,9 +137,12 @@ pub fn impl_unaryop(
|
||||||
for op in ops {
|
for op in ops {
|
||||||
fields.borrow_mut().insert(
|
fields.borrow_mut().insert(
|
||||||
unaryop_name(op).into(),
|
unaryop_name(op).into(),
|
||||||
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
|
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
|
||||||
)),
|
)),
|
||||||
|
false,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -152,6 +161,7 @@ pub fn impl_cmpop(
|
||||||
for op in ops {
|
for op in ops {
|
||||||
fields.borrow_mut().insert(
|
fields.borrow_mut().insert(
|
||||||
comparison_name(op).unwrap().into(),
|
comparison_name(op).unwrap().into(),
|
||||||
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(
|
unifier.add_ty(TypeEnum::TFunc(
|
||||||
FunSignature {
|
FunSignature {
|
||||||
ret: store.bool,
|
ret: store.bool,
|
||||||
|
@ -164,6 +174,8 @@ pub fn impl_cmpop(
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
)),
|
)),
|
||||||
|
false,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -10,7 +10,7 @@ use itertools::izip;
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
self,
|
self,
|
||||||
fold::{self, Fold},
|
fold::{self, Fold},
|
||||||
Arguments, Comprehension, ExprKind, Located, Location, StrRef,
|
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -77,11 +77,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
fn fold_stmt(
|
||||||
|
&mut self,
|
||||||
|
mut node: ast::Stmt<()>,
|
||||||
|
) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||||
let stmt = match node.node {
|
let stmt = match node.node {
|
||||||
// we don't want fold over type annotation
|
// we don't want fold over type annotation
|
||||||
ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment } => {
|
ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => {
|
||||||
self.infer_pattern(&target)?;
|
self.infer_pattern(&target)?;
|
||||||
|
// fix parser problem...
|
||||||
|
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
||||||
|
*ctx = ExprContext::Store;
|
||||||
|
}
|
||||||
|
|
||||||
let target = Box::new(self.fold_expr(*target)?);
|
let target = Box::new(self.fold_expr(*target)?);
|
||||||
let value = if let Some(v) = value {
|
let value = if let Some(v) = value {
|
||||||
let ty = Box::new(self.fold_expr(*v)?);
|
let ty = Box::new(self.fold_expr(*v)?);
|
||||||
|
@ -105,14 +113,25 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
Located {
|
Located {
|
||||||
location: node.location,
|
location: node.location,
|
||||||
custom: None,
|
custom: None,
|
||||||
node: ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment },
|
node: ast::StmtKind::AnnAssign {
|
||||||
|
target,
|
||||||
|
annotation,
|
||||||
|
value,
|
||||||
|
simple,
|
||||||
|
config_comment,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::For { ref target, .. } => {
|
ast::StmtKind::For { ref target, .. } => {
|
||||||
self.infer_pattern(target)?;
|
self.infer_pattern(target)?;
|
||||||
fold::fold_stmt(self, node)?
|
fold::fold_stmt(self, node)?
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assign { ref targets, ref config_comment, .. } => {
|
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
|
||||||
|
for target in targets.iter_mut() {
|
||||||
|
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
||||||
|
*ctx = ExprContext::Store;
|
||||||
|
}
|
||||||
|
}
|
||||||
if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) {
|
if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) {
|
||||||
if let ast::StmtKind::Assign { targets, value, .. } = node.node {
|
if let ast::StmtKind::Assign { targets, value, .. } = node.node {
|
||||||
let value = self.fold_expr(*value)?;
|
let value = self.fold_expr(*value)?;
|
||||||
|
@ -158,7 +177,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
targets,
|
targets,
|
||||||
value: Box::new(value),
|
value: Box::new(value),
|
||||||
type_comment: None,
|
type_comment: None,
|
||||||
config_comment: config_comment.clone()
|
config_comment: config_comment.clone(),
|
||||||
},
|
},
|
||||||
custom: None,
|
custom: None,
|
||||||
});
|
});
|
||||||
|
@ -199,7 +218,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Pass { .. } => {}
|
ast::StmtKind::Break { .. }
|
||||||
|
| ast::StmtKind::Continue { .. }
|
||||||
|
| ast::StmtKind::Pass { .. } => {}
|
||||||
ast::StmtKind::With { items, .. } => {
|
ast::StmtKind::With { items, .. } => {
|
||||||
for item in items.iter() {
|
for item in items.iter() {
|
||||||
let ty = item.context_expr.custom.unwrap();
|
let ty = item.context_expr.custom.unwrap();
|
||||||
|
@ -209,17 +230,21 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
let fields = fields.borrow();
|
let fields = fields.borrow();
|
||||||
fast_path = true;
|
fast_path = true;
|
||||||
if let Some(enter) = fields.get(&"__enter__".into()).cloned() {
|
if let Some(enter) = fields.get(&"__enter__".into()).cloned() {
|
||||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) {
|
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) {
|
||||||
let signature = signature.borrow();
|
let signature = signature.borrow();
|
||||||
if !signature.args.is_empty() {
|
if !signature.args.is_empty() {
|
||||||
return report_error(
|
return report_error(
|
||||||
"__enter__ method should take no argument other than self",
|
"__enter__ method should take no argument other than self",
|
||||||
stmt.location
|
stmt.location,
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
if let Some(var) = &item.optional_vars {
|
if let Some(var) = &item.optional_vars {
|
||||||
if signature.vars.is_empty() {
|
if signature.vars.is_empty() {
|
||||||
self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?;
|
self.unify(
|
||||||
|
signature.ret,
|
||||||
|
var.custom.unwrap(),
|
||||||
|
&stmt.location,
|
||||||
|
)?;
|
||||||
} else {
|
} else {
|
||||||
fast_path = false;
|
fast_path = false;
|
||||||
}
|
}
|
||||||
|
@ -230,17 +255,17 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
} else {
|
} else {
|
||||||
return report_error(
|
return report_error(
|
||||||
"__enter__ method is required for context manager",
|
"__enter__ method is required for context manager",
|
||||||
stmt.location
|
stmt.location,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if let Some(exit) = fields.get(&"__exit__".into()).cloned() {
|
if let Some(exit) = fields.get(&"__exit__".into()).cloned() {
|
||||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) {
|
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) {
|
||||||
let signature = signature.borrow();
|
let signature = signature.borrow();
|
||||||
if !signature.args.is_empty() {
|
if !signature.args.is_empty() {
|
||||||
return report_error(
|
return report_error(
|
||||||
"__exit__ method should take no argument other than self",
|
"__exit__ method should take no argument other than self",
|
||||||
stmt.location
|
stmt.location,
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fast_path = false;
|
fast_path = false;
|
||||||
|
@ -248,26 +273,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
} else {
|
} else {
|
||||||
return report_error(
|
return report_error(
|
||||||
"__exit__ method is required for context manager",
|
"__exit__ method is required for context manager",
|
||||||
stmt.location
|
stmt.location,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !fast_path {
|
if !fast_path {
|
||||||
let enter = TypeEnum::TFunc(RefCell::new(FunSignature {
|
let enter = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()),
|
ret: item.optional_vars.as_ref().map_or_else(
|
||||||
vars: Default::default()
|
|| self.unifier.get_fresh_var().0,
|
||||||
|
|var| var.custom.unwrap(),
|
||||||
|
),
|
||||||
|
vars: Default::default(),
|
||||||
}));
|
}));
|
||||||
let enter = self.unifier.add_ty(enter);
|
let enter = self.unifier.add_ty(enter);
|
||||||
let exit = TypeEnum::TFunc(RefCell::new(FunSignature {
|
let exit = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
ret: self.unifier.get_fresh_var().0,
|
ret: self.unifier.get_fresh_var().0,
|
||||||
vars: Default::default()
|
vars: Default::default(),
|
||||||
}));
|
}));
|
||||||
let exit = self.unifier.add_ty(exit);
|
let exit = self.unifier.add_ty(exit);
|
||||||
let mut fields = HashMap::new();
|
let mut fields = HashMap::new();
|
||||||
fields.insert("__enter__".into(), enter);
|
fields.insert("__enter__".into(), (enter, false));
|
||||||
fields.insert("__exit__".into(), exit);
|
fields.insert("__exit__".into(), (exit, false));
|
||||||
let record = self.unifier.add_record(fields);
|
let record = self.unifier.add_record(fields);
|
||||||
self.unify(ty, record, &stmt.location)?;
|
self.unify(ty, record, &stmt.location)?;
|
||||||
}
|
}
|
||||||
|
@ -330,8 +358,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||||
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
ast::ExprKind::Attribute { value, attr, ctx } => {
|
||||||
Some(self.infer_attribute(value, *attr)?)
|
Some(self.infer_attribute(value, *attr, ctx)?)
|
||||||
}
|
}
|
||||||
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||||
ast::ExprKind::BinOp { left, op, right } => {
|
ast::ExprKind::BinOp { left, op, right } => {
|
||||||
|
@ -399,7 +427,8 @@ impl<'a> Inferencer<'a> {
|
||||||
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
||||||
if class_params.borrow().is_empty() {
|
if class_params.borrow().is_empty() {
|
||||||
if let Some(ty) = fields.borrow().get(&method) {
|
if let Some(ty) = fields.borrow().get(&method) {
|
||||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) {
|
let ty = ty.0;
|
||||||
|
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
|
||||||
let sign = sign.borrow();
|
let sign = sign.borrow();
|
||||||
if sign.vars.is_empty() {
|
if sign.vars.is_empty() {
|
||||||
let call = Call {
|
let call = Call {
|
||||||
|
@ -419,7 +448,7 @@ impl<'a> Inferencer<'a> {
|
||||||
.rev()
|
.rev()
|
||||||
.collect();
|
.collect();
|
||||||
self.unifier
|
self.unifier
|
||||||
.unify_call(&call, *ty, &sign, &required)
|
.unify_call(&call, ty, &sign, &required)
|
||||||
.map_err(|old| format!("{} at {}", old, location))?;
|
.map_err(|old| format!("{} at {}", old, location))?;
|
||||||
return Ok(sign.ret);
|
return Ok(sign.ret);
|
||||||
}
|
}
|
||||||
|
@ -437,7 +466,7 @@ impl<'a> Inferencer<'a> {
|
||||||
});
|
});
|
||||||
self.calls.insert(location.into(), call);
|
self.calls.insert(location.into(), call);
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
let fields = once((method, call)).collect();
|
let fields = once((method, (call, false))).collect();
|
||||||
let record = self.unifier.add_record(fields);
|
let record = self.unifier.add_record(fields);
|
||||||
self.constrain(obj, record, &location)?;
|
self.constrain(obj, record, &location)?;
|
||||||
Ok(ret)
|
Ok(ret)
|
||||||
|
@ -538,7 +567,11 @@ impl<'a> Inferencer<'a> {
|
||||||
let target = new_context.fold_expr(*generator.target)?;
|
let target = new_context.fold_expr(*generator.target)?;
|
||||||
let iter = new_context.fold_expr(*generator.iter)?;
|
let iter = new_context.fold_expr(*generator.iter)?;
|
||||||
if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) {
|
if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) {
|
||||||
new_context.unify(target.custom.unwrap(), new_context.primitives.int32, &target.location)?;
|
new_context.unify(
|
||||||
|
target.custom.unwrap(),
|
||||||
|
new_context.primitives.int32,
|
||||||
|
&target.location,
|
||||||
|
)?;
|
||||||
} else {
|
} else {
|
||||||
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||||
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
|
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
|
||||||
|
@ -755,13 +788,28 @@ impl<'a> Inferencer<'a> {
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
value: &ast::Expr<Option<Type>>,
|
||||||
attr: StrRef,
|
attr: StrRef,
|
||||||
|
ctx: &ExprContext,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
|
let ty = value.custom.unwrap();
|
||||||
|
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
|
||||||
|
// just a fast path
|
||||||
|
let fields = fields.borrow();
|
||||||
|
match (fields.get(&attr), ctx == &ExprContext::Store) {
|
||||||
|
(Some((ty, true)), _) => Ok(*ty),
|
||||||
|
(Some((ty, false)), false) => Ok(*ty),
|
||||||
|
(Some((_, false)), true) => {
|
||||||
|
report_error(&format!("Field {} should be immutable", attr), value.location)
|
||||||
|
}
|
||||||
|
(None, _) => report_error(&format!("No such field {}", attr), value.location),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
let (attr_ty, _) = self.unifier.get_fresh_var();
|
let (attr_ty, _) = self.unifier.get_fresh_var();
|
||||||
let fields = once((attr, attr_ty)).collect();
|
let fields = once((attr, (attr_ty, ctx == &ExprContext::Store))).collect();
|
||||||
let record = self.unifier.add_record(fields);
|
let record = self.unifier.add_record(fields);
|
||||||
self.constrain(value.custom.unwrap(), record, &value.location)?;
|
self.constrain(value.custom.unwrap(), record, &value.location)?;
|
||||||
Ok(attr_ty)
|
Ok(attr_ty)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||||
let b = self.primitives.bool;
|
let b = self.primitives.bool;
|
||||||
|
|
|
@ -8,8 +8,8 @@ use crate::{
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::values::BasicValueEnum;
|
||||||
use itertools::zip;
|
use itertools::zip;
|
||||||
use parking_lot::RwLock;
|
|
||||||
use nac3parser::parser::parse_program;
|
use nac3parser::parser::parse_program;
|
||||||
|
use parking_lot::RwLock;
|
||||||
use test_case::test_case;
|
use test_case::test_case;
|
||||||
|
|
||||||
struct Resolver {
|
struct Resolver {
|
||||||
|
@ -75,7 +75,7 @@ impl TestEnvironment {
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
));
|
));
|
||||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
|
||||||
}
|
}
|
||||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(1),
|
obj_id: DefinitionId(1),
|
||||||
|
@ -170,7 +170,7 @@ impl TestEnvironment {
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
));
|
));
|
||||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
|
||||||
}
|
}
|
||||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(1),
|
obj_id: DefinitionId(1),
|
||||||
|
@ -203,7 +203,9 @@ impl TestEnvironment {
|
||||||
params: HashMap::new().into(),
|
params: HashMap::new().into(),
|
||||||
});
|
});
|
||||||
identifier_mapping.insert("None".into(), none);
|
identifier_mapping.insert("None".into(), none);
|
||||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate() {
|
for (i, name) in
|
||||||
|
["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate()
|
||||||
|
{
|
||||||
top_level_defs.push(
|
top_level_defs.push(
|
||||||
RwLock::new(TopLevelDef::Class {
|
RwLock::new(TopLevelDef::Class {
|
||||||
name: (*name).into(),
|
name: (*name).into(),
|
||||||
|
@ -225,7 +227,7 @@ impl TestEnvironment {
|
||||||
|
|
||||||
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(7),
|
obj_id: DefinitionId(7),
|
||||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
});
|
});
|
||||||
top_level_defs.push(
|
top_level_defs.push(
|
||||||
|
@ -233,7 +235,7 @@ impl TestEnvironment {
|
||||||
name: "Foo".into(),
|
name: "Foo".into(),
|
||||||
object_id: DefinitionId(7),
|
object_id: DefinitionId(7),
|
||||||
type_vars: vec![v0],
|
type_vars: vec![v0],
|
||||||
fields: [("a".into(), v0)].into(),
|
fields: [("a".into(), v0, true)].into(),
|
||||||
methods: Default::default(),
|
methods: Default::default(),
|
||||||
ancestors: Default::default(),
|
ancestors: Default::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -259,7 +261,7 @@ impl TestEnvironment {
|
||||||
));
|
));
|
||||||
let bar = unifier.add_ty(TypeEnum::TObj {
|
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(8),
|
obj_id: DefinitionId(8),
|
||||||
fields: [("a".into(), int32), ("b".into(), fun)]
|
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
|
||||||
.iter()
|
.iter()
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
|
@ -271,7 +273,7 @@ impl TestEnvironment {
|
||||||
name: "Bar".into(),
|
name: "Bar".into(),
|
||||||
object_id: DefinitionId(8),
|
object_id: DefinitionId(8),
|
||||||
type_vars: Default::default(),
|
type_vars: Default::default(),
|
||||||
fields: [("a".into(), int32), ("b".into(), fun)].into(),
|
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
|
||||||
methods: Default::default(),
|
methods: Default::default(),
|
||||||
ancestors: Default::default(),
|
ancestors: Default::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -288,7 +290,7 @@ impl TestEnvironment {
|
||||||
|
|
||||||
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(9),
|
obj_id: DefinitionId(9),
|
||||||
fields: [("a".into(), bool), ("b".into(), fun)]
|
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
|
||||||
.iter()
|
.iter()
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
|
@ -300,7 +302,7 @@ impl TestEnvironment {
|
||||||
name: "Bar2".into(),
|
name: "Bar2".into(),
|
||||||
object_id: DefinitionId(9),
|
object_id: DefinitionId(9),
|
||||||
type_vars: Default::default(),
|
type_vars: Default::default(),
|
||||||
fields: [("a".into(), bool), ("b".into(), fun)].into(),
|
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
|
||||||
methods: Default::default(),
|
methods: Default::default(),
|
||||||
ancestors: Default::default(),
|
ancestors: Default::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
|
|
@ -49,7 +49,7 @@ pub struct FunSignature {
|
||||||
pub enum TypeVarMeta {
|
pub enum TypeVarMeta {
|
||||||
Generic,
|
Generic,
|
||||||
Sequence(RefCell<Mapping<i32>>),
|
Sequence(RefCell<Mapping<i32>>),
|
||||||
Record(RefCell<Mapping<StrRef>>),
|
Record(RefCell<Mapping<StrRef, (Type, bool)>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -71,7 +71,7 @@ pub enum TypeEnum {
|
||||||
},
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: RefCell<Mapping<StrRef>>,
|
fields: RefCell<Mapping<StrRef, (Type, bool)>>,
|
||||||
params: RefCell<VarMap>,
|
params: RefCell<VarMap>,
|
||||||
},
|
},
|
||||||
TVirtual {
|
TVirtual {
|
||||||
|
@ -155,7 +155,7 @@ impl Unifier {
|
||||||
self.unification_table.new_key(Rc::new(a))
|
self.unification_table.new_key(Rc::new(a))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_record(&mut self, fields: Mapping<StrRef>) -> Type {
|
pub fn add_record(&mut self, fields: Mapping<StrRef, (Type, bool)>) -> Type {
|
||||||
let id = self.var_id + 1;
|
let id = self.var_id + 1;
|
||||||
self.var_id += 1;
|
self.var_id += 1;
|
||||||
self.add_ty(TypeEnum::TVar {
|
self.add_ty(TypeEnum::TVar {
|
||||||
|
@ -394,11 +394,12 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
(Record(fields1), Record(fields2)) => {
|
(Record(fields1), Record(fields2)) => {
|
||||||
let mut fields2 = fields2.borrow_mut();
|
let mut fields2 = fields2.borrow_mut();
|
||||||
for (key, value) in fields1.borrow().iter() {
|
for (key, (ty, is_mutable)) in fields1.borrow().iter() {
|
||||||
if let Some(ty) = fields2.get(key) {
|
if let Some((ty2, is_mutable2)) = fields2.get_mut(key) {
|
||||||
self.unify_impl(*ty, *value, false)?;
|
self.unify_impl(*ty2, *ty, false)?;
|
||||||
|
*is_mutable2 |= *is_mutable;
|
||||||
} else {
|
} else {
|
||||||
fields2.insert(*key, *value);
|
fields2.insert(*key, (*ty, *is_mutable));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -495,13 +496,19 @@ impl Unifier {
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||||
for (k, v) in map.borrow().iter() {
|
for (k, (ty, is_mutable)) in map.borrow().iter() {
|
||||||
let ty = fields
|
let (ty2, is_mutable2) = fields
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(k)
|
.get(k)
|
||||||
.copied()
|
.copied()
|
||||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
self.unify_impl(ty, *v, false)?;
|
// typevar represents the usage of the variable
|
||||||
|
// it is OK to have immutable usage for mutable fields
|
||||||
|
// but cannot have mutable usage for immutable fields
|
||||||
|
if *is_mutable && !is_mutable2 {
|
||||||
|
return Err(format!("Field {} should be immutable", k));
|
||||||
|
}
|
||||||
|
self.unify_impl(*ty, ty2, false)?;
|
||||||
}
|
}
|
||||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
|
@ -510,16 +517,19 @@ impl Unifier {
|
||||||
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
||||||
let ty = self.get_ty(*ty);
|
let ty = self.get_ty(*ty);
|
||||||
if let TObj { fields, .. } = ty.as_ref() {
|
if let TObj { fields, .. } = ty.as_ref() {
|
||||||
for (k, v) in map.borrow().iter() {
|
for (k, (ty, is_mutable)) in map.borrow().iter() {
|
||||||
let ty = fields
|
let (ty2, is_mutable2) = fields
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(k)
|
.get(k)
|
||||||
.copied()
|
.copied()
|
||||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
if !matches!(self.get_ty(ty2).as_ref(), TFunc { .. }) {
|
||||||
return Err(format!("Cannot access field {} for virtual type", k));
|
return Err(format!("Cannot access field {} for virtual type", k));
|
||||||
}
|
}
|
||||||
self.unify_impl(*v, ty, false)?;
|
if *is_mutable && !is_mutable2 {
|
||||||
|
return Err(format!("Field {} should be immutable", k));
|
||||||
|
}
|
||||||
|
self.unify_impl(*ty, ty2, false)?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// require annotation...
|
// require annotation...
|
||||||
|
@ -643,7 +653,9 @@ impl Unifier {
|
||||||
let fields = fields
|
let fields = fields
|
||||||
.borrow()
|
.borrow()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
.map(|(k, (v, _))| {
|
||||||
|
format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))
|
||||||
|
})
|
||||||
.join(", ");
|
.join(", ");
|
||||||
format!("record[{}]", fields)
|
format!("record[{}]", fields)
|
||||||
}
|
}
|
||||||
|
@ -805,7 +817,7 @@ impl Unifier {
|
||||||
let params =
|
let params =
|
||||||
self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
||||||
let fields = self
|
let fields = self
|
||||||
.subst_map(&fields.borrow(), mapping, cache)
|
.subst_map2(&fields.borrow(), mapping, cache)
|
||||||
.unwrap_or_else(|| fields.borrow().clone());
|
.unwrap_or_else(|| fields.borrow().clone());
|
||||||
let new_ty = self.add_ty(TypeEnum::TObj {
|
let new_ty = self.add_ty(TypeEnum::TObj {
|
||||||
obj_id,
|
obj_id,
|
||||||
|
@ -873,6 +885,27 @@ impl Unifier {
|
||||||
map2
|
map2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn subst_map2<K>(
|
||||||
|
&mut self,
|
||||||
|
map: &Mapping<K, (Type, bool)>,
|
||||||
|
mapping: &VarMap,
|
||||||
|
cache: &mut HashMap<Type, Option<Type>>,
|
||||||
|
) -> Option<Mapping<K, (Type, bool)>>
|
||||||
|
where
|
||||||
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
|
{
|
||||||
|
let mut map2 = None;
|
||||||
|
for (k, (v, mutability)) in map.iter() {
|
||||||
|
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||||
|
if map2.is_none() {
|
||||||
|
map2 = Some(map.clone());
|
||||||
|
}
|
||||||
|
*map2.as_mut().unwrap().get_mut(k).unwrap() = (v1, *mutability);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map2
|
||||||
|
}
|
||||||
|
|
||||||
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
||||||
use TypeEnum::*;
|
use TypeEnum::*;
|
||||||
let x = self.get_ty(a);
|
let x = self.get_ty(a);
|
||||||
|
|
|
@ -39,7 +39,7 @@ impl Unifier {
|
||||||
(
|
(
|
||||||
TypeEnum::TVar { meta: Record(fields1), .. },
|
TypeEnum::TVar { meta: Record(fields1), .. },
|
||||||
TypeEnum::TVar { meta: Record(fields2), .. },
|
TypeEnum::TVar { meta: Record(fields2), .. },
|
||||||
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
|
) => self.map_eq2(&fields1.borrow(), &fields2.borrow()),
|
||||||
(
|
(
|
||||||
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||||
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||||
|
@ -63,6 +63,25 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn map_eq2<K>(
|
||||||
|
&mut self,
|
||||||
|
map1: &Mapping<K, (Type, bool)>,
|
||||||
|
map2: &Mapping<K, (Type, bool)>,
|
||||||
|
) -> bool
|
||||||
|
where
|
||||||
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
|
{
|
||||||
|
if map1.len() != map2.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (k, (ty1, m1)) in map1.iter() {
|
||||||
|
if !map2.get(k).map(|(ty2, m2)| m1 == m2 && self.eq(*ty1, *ty2)).unwrap_or(false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestEnvironment {
|
struct TestEnvironment {
|
||||||
|
@ -104,7 +123,11 @@ impl TestEnvironment {
|
||||||
"Foo".into(),
|
"Foo".into(),
|
||||||
unifier.add_ty(TypeEnum::TObj {
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(3),
|
obj_id: DefinitionId(3),
|
||||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
fields: [("a".into(), (v0, true))]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashMap<_, _>>()
|
||||||
|
.into(),
|
||||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
@ -151,7 +174,7 @@ impl TestEnvironment {
|
||||||
let eq = s.find('=').unwrap();
|
let eq = s.find('=').unwrap();
|
||||||
let key = s[1..eq].into();
|
let key = s[1..eq].into();
|
||||||
let result = self.internal_parse(&s[eq + 1..], mapping);
|
let result = self.internal_parse(&s[eq + 1..], mapping);
|
||||||
fields.insert(key, result.0);
|
fields.insert(key, (result.0, true));
|
||||||
s = result.1;
|
s = result.1;
|
||||||
}
|
}
|
||||||
(self.unifier.add_record(fields), &s[1..])
|
(self.unifier.add_record(fields), &s[1..])
|
||||||
|
@ -326,7 +349,7 @@ fn test_recursive_subst() {
|
||||||
let foo_ty = env.unifier.get_ty(foo_id);
|
let foo_ty = env.unifier.get_ty(foo_id);
|
||||||
let mapping: HashMap<_, _>;
|
let mapping: HashMap<_, _>;
|
||||||
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
|
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
|
||||||
fields.borrow_mut().insert("rec".into(), foo_id);
|
fields.borrow_mut().insert("rec".into(), (foo_id, true));
|
||||||
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
|
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
|
@ -335,8 +358,8 @@ fn test_recursive_subst() {
|
||||||
let instantiated_ty = env.unifier.get_ty(instantiated);
|
let instantiated_ty = env.unifier.get_ty(instantiated);
|
||||||
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
||||||
let fields = fields.borrow();
|
let fields = fields.borrow();
|
||||||
assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int));
|
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
|
||||||
assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated));
|
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
@ -351,7 +374,7 @@ fn test_virtual() {
|
||||||
));
|
));
|
||||||
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: DefinitionId(5),
|
obj_id: DefinitionId(5),
|
||||||
fields: [("f".into(), fun), ("a".into(), int)]
|
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
|
||||||
.iter()
|
.iter()
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<HashMap<StrRef, _>>()
|
.collect::<HashMap<StrRef, _>>()
|
||||||
|
@ -363,15 +386,15 @@ fn test_virtual() {
|
||||||
|
|
||||||
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
||||||
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
||||||
let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect());
|
let c = env.unifier.add_record([("f".into(), (v1, false))].iter().cloned().collect());
|
||||||
env.unifier.unify(a, b).unwrap();
|
env.unifier.unify(a, b).unwrap();
|
||||||
env.unifier.unify(b, c).unwrap();
|
env.unifier.unify(b, c).unwrap();
|
||||||
assert!(env.unifier.eq(v1, fun));
|
assert!(env.unifier.eq(v1, fun));
|
||||||
|
|
||||||
let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect());
|
let d = env.unifier.add_record([("a".into(), (v1, true))].iter().cloned().collect());
|
||||||
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
|
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
|
||||||
|
|
||||||
let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect());
|
let d = env.unifier.add_record([("b".into(), (v1, true))].iter().cloned().collect());
|
||||||
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
|
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue