forked from M-Labs/nac3
nac3core: type check invariants
This rejects code that tries to assign to KernelInvariant fields and methods.
This commit is contained in:
parent
7385b91113
commit
b1e83a1fd4
|
@ -135,7 +135,7 @@ impl Resolver {
|
|||
.collect();
|
||||
let mut fields_ty = HashMap::new();
|
||||
for method in methods.iter() {
|
||||
fields_ty.insert(method.0, method.1);
|
||||
fields_ty.insert(method.0, (method.1, false));
|
||||
}
|
||||
for field in fields.iter() {
|
||||
let name: String = field.0.into();
|
||||
|
@ -148,7 +148,7 @@ impl Resolver {
|
|||
// field type mismatch
|
||||
return Ok(None);
|
||||
}
|
||||
fields_ty.insert(field.0, ty);
|
||||
fields_ty.insert(field.0, (ty, field.2));
|
||||
}
|
||||
for (_, ty) in var_map.iter() {
|
||||
// must be concrete type
|
||||
|
@ -379,7 +379,7 @@ impl Resolver {
|
|||
if let TopLevelDef::Class { fields, .. } = &*definition {
|
||||
let values: Result<Option<Vec<_>>, _> = fields
|
||||
.iter()
|
||||
.map(|(name, _)| {
|
||||
.map(|(name, _, _)| {
|
||||
self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx)
|
||||
})
|
||||
.collect();
|
||||
|
@ -413,7 +413,7 @@ impl SymbolResolver for Resolver {
|
|||
id_to_type.get(&str).cloned().or_else(|| {
|
||||
let py_id = self.name_to_pyid.get(&str);
|
||||
let result = py_id.and_then(|id| {
|
||||
self.pyid_to_type.read().get(&id).copied().or_else(|| {
|
||||
self.pyid_to_type.read().get(id).copied().or_else(|| {
|
||||
Python::with_gil(|py| -> PyResult<Option<Type>> {
|
||||
let obj: &PyAny = self.module.extract(py)?;
|
||||
let members: &PyList = PyModule::import(py, "inspect")?
|
||||
|
@ -491,7 +491,7 @@ impl SymbolResolver for Resolver {
|
|||
let mut id_to_def = self.id_to_def.lock();
|
||||
id_to_def.get(&id).cloned().or_else(|| {
|
||||
let py_id = self.name_to_pyid.get(&id);
|
||||
let result = py_id.and_then(|id| self.pyid_to_def.read().get(&id).copied());
|
||||
let result = py_id.and_then(|id| self.pyid_to_def.read().get(id).copied());
|
||||
if let Some(result) = &result {
|
||||
id_to_def.insert(id, *result);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ pub enum ConcreteTypeEnum {
|
|||
},
|
||||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: HashMap<StrRef, ConcreteType>,
|
||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||
params: HashMap<u32, ConcreteType>,
|
||||
},
|
||||
TVirtual {
|
||||
|
@ -148,7 +148,7 @@ impl ConcreteTypeStore {
|
|||
.borrow()
|
||||
.iter()
|
||||
.map(|(name, ty)| {
|
||||
(*name, self.from_unifier_type(unifier, primitives, *ty, cache))
|
||||
(*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))
|
||||
})
|
||||
.collect(),
|
||||
params: params
|
||||
|
@ -225,7 +225,7 @@ impl ConcreteTypeStore {
|
|||
fields: fields
|
||||
.iter()
|
||||
.map(|(name, cty)| {
|
||||
(*name, self.to_unifier_type(unifier, primitives, *cty, cache))
|
||||
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
|
||||
})
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
|
|
|
@ -232,7 +232,7 @@ fn get_llvm_type<'ctx>(
|
|||
let fields = fields.borrow();
|
||||
let fields = fields_list
|
||||
.iter()
|
||||
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
|
||||
.map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0].0))
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
} else {
|
||||
|
|
|
@ -2,16 +2,19 @@ use std::collections::HashMap;
|
|||
use std::fmt::Debug;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use crate::{codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, Unifier},
|
||||
};
|
||||
use crate::{
|
||||
codegen::CodeGenContext,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
};
|
||||
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
||||
use itertools::{chain, izip};
|
||||
use parking_lot::RwLock;
|
||||
use nac3parser::ast::{Expr, StrRef};
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::{chain, izip};
|
||||
use nac3parser::ast::{Expr, StrRef};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum SymbolValue {
|
||||
|
@ -35,7 +38,11 @@ pub trait SymbolResolver {
|
|||
) -> Option<Type>;
|
||||
// get the top-level definition of identifiers
|
||||
fn get_identifier_def(&self, str: StrRef) -> Option<DefinitionId>;
|
||||
fn get_symbol_value<'ctx, 'a>(&self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>) -> Option<BasicValueEnum<'ctx>>;
|
||||
fn get_symbol_value<'ctx, 'a>(
|
||||
&self,
|
||||
str: StrRef,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
) -> Option<BasicValueEnum<'ctx>>;
|
||||
fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
|
||||
// handle function call etc.
|
||||
}
|
||||
|
@ -62,9 +69,7 @@ pub fn parse_type_annotation<T>(
|
|||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
use nac3parser::ast::ExprKind::*;
|
||||
let ids = IDENTIFIER_ID.with(|ids| {
|
||||
*ids
|
||||
});
|
||||
let ids = IDENTIFIER_ID.with(|ids| *ids);
|
||||
let int32_id = ids[0];
|
||||
let int64_id = ids[1];
|
||||
let float_id = ids[2];
|
||||
|
@ -99,8 +104,8 @@ pub fn parse_type_annotation<T>(
|
|||
}
|
||||
let fields = RefCell::new(
|
||||
chain(
|
||||
fields.iter().map(|(k, v)| (*k, *v)),
|
||||
methods.iter().map(|(k, v, _)| (*k, *v)),
|
||||
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
|
||||
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
|
||||
)
|
||||
.collect(),
|
||||
);
|
||||
|
@ -124,7 +129,7 @@ pub fn parse_type_annotation<T>(
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Subscript { value, slice, .. } => {
|
||||
if let Name { id, .. } = &value.node {
|
||||
if *id == virtual_id {
|
||||
|
@ -209,14 +214,14 @@ pub fn parse_type_annotation<T>(
|
|||
}
|
||||
let mut fields = fields
|
||||
.iter()
|
||||
.map(|(attr, ty)| {
|
||||
.map(|(attr, ty, is_mutable)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*attr, ty)
|
||||
(*attr, (ty, *is_mutable))
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*attr, ty)
|
||||
(*attr, (ty, false))
|
||||
}));
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
|
|
|
@ -461,7 +461,7 @@ impl TopLevelComposer {
|
|||
"str".into(),
|
||||
"self".into(),
|
||||
"Kernel".into(),
|
||||
"KernelImmutable".into(),
|
||||
"KernelInvariant".into(),
|
||||
]);
|
||||
let defined_names: HashSet<String> = Default::default();
|
||||
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
|
||||
|
@ -1391,22 +1391,24 @@ impl TopLevelComposer {
|
|||
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
||||
if defined_fields.insert(attr.to_string()) {
|
||||
let dummy_field_type = unifier.get_fresh_var().0;
|
||||
class_fields_def.push((*attr, dummy_field_type));
|
||||
|
||||
// handle Kernel[T], KernelImmutable[T]
|
||||
let annotation = {
|
||||
match &annotation.as_ref().node {
|
||||
ast::ExprKind::Subscript { value, slice, .. }
|
||||
if {
|
||||
matches!(&value.node, ast::ExprKind::Name { id, .. }
|
||||
if id == &"Kernel".into() || id == &"KernelImmutable".into())
|
||||
} =>
|
||||
{
|
||||
slice
|
||||
// handle Kernel[T], KernelInvariant[T]
|
||||
let (annotation, mutable) = {
|
||||
let mut result = None;
|
||||
if let ast::ExprKind::Subscript { value, slice, .. } = &annotation.as_ref().node {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
result = if id == &"Kernel".into() {
|
||||
Some((slice, true))
|
||||
} else if id == &"KernelInvariant".into() {
|
||||
Some((slice, false))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
_ => annotation,
|
||||
}
|
||||
}
|
||||
result.unwrap_or((annotation, true))
|
||||
};
|
||||
class_fields_def.push((*attr, dummy_field_type, mutable));
|
||||
|
||||
let annotation = parse_ast_to_type_annotation_kinds(
|
||||
class_resolver,
|
||||
|
@ -1532,26 +1534,13 @@ impl TopLevelComposer {
|
|||
class_methods_def.extend(new_child_methods);
|
||||
|
||||
// handle class fields
|
||||
let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new();
|
||||
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
|
||||
// let mut is_override: HashSet<_> = HashSet::new();
|
||||
for (anc_field_name, anc_field_ty) in fields {
|
||||
let to_be_added = (*anc_field_name, *anc_field_ty);
|
||||
for (anc_field_name, anc_field_ty, mutable) in fields {
|
||||
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
|
||||
// find if there is a fields with the same name in the child class
|
||||
for (class_field_name, ..) in class_fields_def.iter() {
|
||||
if class_field_name == anc_field_name {
|
||||
// let ok = Self::check_overload_field_type(
|
||||
// *class_field_ty,
|
||||
// *anc_field_ty,
|
||||
// unifier,
|
||||
// type_var_to_concrete_def,
|
||||
// );
|
||||
// if !ok {
|
||||
// return Err("fields has same name as ancestors' field, but incompatible type".into());
|
||||
// }
|
||||
// // mark it as added
|
||||
// is_override.insert(class_field_name.to_string());
|
||||
// to_be_added = (class_field_name.to_string(), *class_field_ty);
|
||||
// break;
|
||||
return Err(format!(
|
||||
"field `{}` has already declared in the ancestor classes",
|
||||
class_field_name
|
||||
|
@ -1560,9 +1549,9 @@ impl TopLevelComposer {
|
|||
}
|
||||
new_child_fields.push(to_be_added);
|
||||
}
|
||||
for (class_field_name, class_field_ty) in class_fields_def.iter() {
|
||||
for (class_field_name, class_field_ty, mutable) in class_fields_def.iter() {
|
||||
if !is_override.contains(class_field_name) {
|
||||
new_child_fields.push((*class_field_name, *class_field_ty));
|
||||
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
|
||||
}
|
||||
}
|
||||
class_fields_def.drain(..);
|
||||
|
@ -1636,7 +1625,7 @@ impl TopLevelComposer {
|
|||
unreachable!("must be init function here")
|
||||
}
|
||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
|
||||
if fields.iter().any(|x| !all_inited.contains(&x.0)) {
|
||||
return Err(format!(
|
||||
"fields of class {} not fully initialized",
|
||||
class_name
|
||||
|
|
|
@ -12,7 +12,7 @@ impl TopLevelDef {
|
|||
} => {
|
||||
let fields_str = fields
|
||||
.iter()
|
||||
.map(|(n, ty)| {
|
||||
.map(|(n, ty, _)| {
|
||||
(n.to_string(), unifier.default_stringify(*ty))
|
||||
})
|
||||
.collect_vec();
|
||||
|
|
|
@ -85,7 +85,8 @@ pub enum TopLevelDef {
|
|||
/// type variables bounded to the class.
|
||||
type_vars: Vec<Type>,
|
||||
// class fields
|
||||
fields: Vec<(StrRef, Type)>,
|
||||
// name, type, is mutable
|
||||
fields: Vec<(StrRef, Type, bool)>,
|
||||
// class methods, pointing to the corresponding function definition.
|
||||
methods: Vec<(StrRef, Type, DefinitionId)>,
|
||||
// ancestor classes, including itself.
|
||||
|
|
|
@ -307,25 +307,15 @@ pub fn get_type_from_type_annotation_kinds(
|
|||
.iter()
|
||||
.map(|(name, ty, _)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*name, subst_ty)
|
||||
// methods are immutable
|
||||
(*name, (subst_ty, false))
|
||||
})
|
||||
.collect::<HashMap<_, Type>>();
|
||||
tobj_fields.extend(fields.iter().map(|(name, ty)| {
|
||||
.collect::<HashMap<_, _>>();
|
||||
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*name, subst_ty)
|
||||
(*name, (subst_ty, *mutability))
|
||||
}));
|
||||
|
||||
// println!("tobj_fields: {:?}", tobj_fields);
|
||||
// println!(
|
||||
// "{:?}: {}\n",
|
||||
// tobj_fields.get("__init__").unwrap(),
|
||||
// unifier.stringify(
|
||||
// *tobj_fields.get("__init__").unwrap(),
|
||||
// &mut |id| format!("class{}", id),
|
||||
// &mut |id| format!("tvar{}", id)
|
||||
// )
|
||||
// );
|
||||
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: RefCell::new(tobj_fields),
|
||||
|
|
|
@ -86,6 +86,7 @@ pub fn impl_binop(
|
|||
};
|
||||
for op in ops {
|
||||
fields.borrow_mut().insert(binop_name(op).into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: ret_ty,
|
||||
|
@ -97,10 +98,13 @@ pub fn impl_binop(
|
|||
}],
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
)),
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
fields.borrow_mut().insert(binop_assign_name(op).into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: store.none,
|
||||
|
@ -112,7 +116,9 @@ pub fn impl_binop(
|
|||
}],
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
)),
|
||||
false,
|
||||
)
|
||||
});
|
||||
}
|
||||
} else {
|
||||
|
@ -131,9 +137,12 @@ pub fn impl_unaryop(
|
|||
for op in ops {
|
||||
fields.borrow_mut().insert(
|
||||
unaryop_name(op).into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(),
|
||||
)),
|
||||
false,
|
||||
),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
|
@ -152,6 +161,7 @@ pub fn impl_cmpop(
|
|||
for op in ops {
|
||||
fields.borrow_mut().insert(
|
||||
comparison_name(op).unwrap().into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(
|
||||
FunSignature {
|
||||
ret: store.bool,
|
||||
|
@ -164,6 +174,8 @@ pub fn impl_cmpop(
|
|||
}
|
||||
.into(),
|
||||
)),
|
||||
false,
|
||||
),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -10,7 +10,7 @@ use itertools::izip;
|
|||
use nac3parser::ast::{
|
||||
self,
|
||||
fold::{self, Fold},
|
||||
Arguments, Comprehension, ExprKind, Located, Location, StrRef,
|
||||
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -77,11 +77,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||
fn fold_stmt(
|
||||
&mut self,
|
||||
mut node: ast::Stmt<()>,
|
||||
) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||
let stmt = match node.node {
|
||||
// we don't want fold over type annotation
|
||||
ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment } => {
|
||||
ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => {
|
||||
self.infer_pattern(&target)?;
|
||||
// fix parser problem...
|
||||
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
||||
*ctx = ExprContext::Store;
|
||||
}
|
||||
|
||||
let target = Box::new(self.fold_expr(*target)?);
|
||||
let value = if let Some(v) = value {
|
||||
let ty = Box::new(self.fold_expr(*v)?);
|
||||
|
@ -105,14 +113,25 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
Located {
|
||||
location: node.location,
|
||||
custom: None,
|
||||
node: ast::StmtKind::AnnAssign { target, annotation, value, simple, config_comment },
|
||||
node: ast::StmtKind::AnnAssign {
|
||||
target,
|
||||
annotation,
|
||||
value,
|
||||
simple,
|
||||
config_comment,
|
||||
},
|
||||
}
|
||||
}
|
||||
ast::StmtKind::For { ref target, .. } => {
|
||||
self.infer_pattern(target)?;
|
||||
fold::fold_stmt(self, node)?
|
||||
}
|
||||
ast::StmtKind::Assign { ref targets, ref config_comment, .. } => {
|
||||
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
|
||||
for target in targets.iter_mut() {
|
||||
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
||||
*ctx = ExprContext::Store;
|
||||
}
|
||||
}
|
||||
if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) {
|
||||
if let ast::StmtKind::Assign { targets, value, .. } = node.node {
|
||||
let value = self.fold_expr(*value)?;
|
||||
|
@ -158,7 +177,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
targets,
|
||||
value: Box::new(value),
|
||||
type_comment: None,
|
||||
config_comment: config_comment.clone()
|
||||
config_comment: config_comment.clone(),
|
||||
},
|
||||
custom: None,
|
||||
});
|
||||
|
@ -199,7 +218,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
}
|
||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||
ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Pass { .. } => {}
|
||||
ast::StmtKind::Break { .. }
|
||||
| ast::StmtKind::Continue { .. }
|
||||
| ast::StmtKind::Pass { .. } => {}
|
||||
ast::StmtKind::With { items, .. } => {
|
||||
for item in items.iter() {
|
||||
let ty = item.context_expr.custom.unwrap();
|
||||
|
@ -209,17 +230,21 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
let fields = fields.borrow();
|
||||
fast_path = true;
|
||||
if let Some(enter) = fields.get(&"__enter__".into()).cloned() {
|
||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) {
|
||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) {
|
||||
let signature = signature.borrow();
|
||||
if !signature.args.is_empty() {
|
||||
return report_error(
|
||||
"__enter__ method should take no argument other than self",
|
||||
stmt.location
|
||||
)
|
||||
stmt.location,
|
||||
);
|
||||
}
|
||||
if let Some(var) = &item.optional_vars {
|
||||
if signature.vars.is_empty() {
|
||||
self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?;
|
||||
self.unify(
|
||||
signature.ret,
|
||||
var.custom.unwrap(),
|
||||
&stmt.location,
|
||||
)?;
|
||||
} else {
|
||||
fast_path = false;
|
||||
}
|
||||
|
@ -230,17 +255,17 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
} else {
|
||||
return report_error(
|
||||
"__enter__ method is required for context manager",
|
||||
stmt.location
|
||||
stmt.location,
|
||||
);
|
||||
}
|
||||
if let Some(exit) = fields.get(&"__exit__".into()).cloned() {
|
||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) {
|
||||
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) {
|
||||
let signature = signature.borrow();
|
||||
if !signature.args.is_empty() {
|
||||
return report_error(
|
||||
"__exit__ method should take no argument other than self",
|
||||
stmt.location
|
||||
)
|
||||
stmt.location,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
fast_path = false;
|
||||
|
@ -248,26 +273,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
} else {
|
||||
return report_error(
|
||||
"__exit__ method is required for context manager",
|
||||
stmt.location
|
||||
stmt.location,
|
||||
);
|
||||
}
|
||||
}
|
||||
if !fast_path {
|
||||
let enter = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||
args: vec![],
|
||||
ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()),
|
||||
vars: Default::default()
|
||||
ret: item.optional_vars.as_ref().map_or_else(
|
||||
|| self.unifier.get_fresh_var().0,
|
||||
|var| var.custom.unwrap(),
|
||||
),
|
||||
vars: Default::default(),
|
||||
}));
|
||||
let enter = self.unifier.add_ty(enter);
|
||||
let exit = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||
args: vec![],
|
||||
ret: self.unifier.get_fresh_var().0,
|
||||
vars: Default::default()
|
||||
vars: Default::default(),
|
||||
}));
|
||||
let exit = self.unifier.add_ty(exit);
|
||||
let mut fields = HashMap::new();
|
||||
fields.insert("__enter__".into(), enter);
|
||||
fields.insert("__exit__".into(), exit);
|
||||
fields.insert("__enter__".into(), (enter, false));
|
||||
fields.insert("__exit__".into(), (exit, false));
|
||||
let record = self.unifier.add_record(fields);
|
||||
self.unify(ty, record, &stmt.location)?;
|
||||
}
|
||||
|
@ -330,8 +358,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
||||
Some(self.infer_attribute(value, *attr)?)
|
||||
ast::ExprKind::Attribute { value, attr, ctx } => {
|
||||
Some(self.infer_attribute(value, *attr, ctx)?)
|
||||
}
|
||||
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||
ast::ExprKind::BinOp { left, op, right } => {
|
||||
|
@ -399,7 +427,8 @@ impl<'a> Inferencer<'a> {
|
|||
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
||||
if class_params.borrow().is_empty() {
|
||||
if let Some(ty) = fields.borrow().get(&method) {
|
||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) {
|
||||
let ty = ty.0;
|
||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
|
||||
let sign = sign.borrow();
|
||||
if sign.vars.is_empty() {
|
||||
let call = Call {
|
||||
|
@ -419,7 +448,7 @@ impl<'a> Inferencer<'a> {
|
|||
.rev()
|
||||
.collect();
|
||||
self.unifier
|
||||
.unify_call(&call, *ty, &sign, &required)
|
||||
.unify_call(&call, ty, &sign, &required)
|
||||
.map_err(|old| format!("{} at {}", old, location))?;
|
||||
return Ok(sign.ret);
|
||||
}
|
||||
|
@ -437,7 +466,7 @@ impl<'a> Inferencer<'a> {
|
|||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
let fields = once((method, call)).collect();
|
||||
let fields = once((method, (call, false))).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
self.constrain(obj, record, &location)?;
|
||||
Ok(ret)
|
||||
|
@ -538,7 +567,11 @@ impl<'a> Inferencer<'a> {
|
|||
let target = new_context.fold_expr(*generator.target)?;
|
||||
let iter = new_context.fold_expr(*generator.iter)?;
|
||||
if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) {
|
||||
new_context.unify(target.custom.unwrap(), new_context.primitives.int32, &target.location)?;
|
||||
new_context.unify(
|
||||
target.custom.unwrap(),
|
||||
new_context.primitives.int32,
|
||||
&target.location,
|
||||
)?;
|
||||
} else {
|
||||
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
|
||||
|
@ -755,13 +788,28 @@ impl<'a> Inferencer<'a> {
|
|||
&mut self,
|
||||
value: &ast::Expr<Option<Type>>,
|
||||
attr: StrRef,
|
||||
ctx: &ExprContext,
|
||||
) -> InferenceResult {
|
||||
let ty = value.custom.unwrap();
|
||||
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
|
||||
// just a fast path
|
||||
let fields = fields.borrow();
|
||||
match (fields.get(&attr), ctx == &ExprContext::Store) {
|
||||
(Some((ty, true)), _) => Ok(*ty),
|
||||
(Some((ty, false)), false) => Ok(*ty),
|
||||
(Some((_, false)), true) => {
|
||||
report_error(&format!("Field {} should be immutable", attr), value.location)
|
||||
}
|
||||
(None, _) => report_error(&format!("No such field {}", attr), value.location),
|
||||
}
|
||||
} else {
|
||||
let (attr_ty, _) = self.unifier.get_fresh_var();
|
||||
let fields = once((attr, attr_ty)).collect();
|
||||
let fields = once((attr, (attr_ty, ctx == &ExprContext::Store))).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
self.constrain(value.custom.unwrap(), record, &value.location)?;
|
||||
Ok(attr_ty)
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {
|
||||
let b = self.primitives.bool;
|
||||
|
|
|
@ -8,8 +8,8 @@ use crate::{
|
|||
use indoc::indoc;
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::zip;
|
||||
use parking_lot::RwLock;
|
||||
use nac3parser::parser::parse_program;
|
||||
use parking_lot::RwLock;
|
||||
use test_case::test_case;
|
||||
|
||||
struct Resolver {
|
||||
|
@ -75,7 +75,7 @@ impl TestEnvironment {
|
|||
}
|
||||
.into(),
|
||||
));
|
||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
||||
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
|
||||
}
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
|
@ -170,7 +170,7 @@ impl TestEnvironment {
|
|||
}
|
||||
.into(),
|
||||
));
|
||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
||||
fields.borrow_mut().insert("__add__".into(), (add_ty, false));
|
||||
}
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
|
@ -203,7 +203,9 @@ impl TestEnvironment {
|
|||
params: HashMap::new().into(),
|
||||
});
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate() {
|
||||
for (i, name) in
|
||||
["int32", "int64", "float", "bool", "none", "range", "str"].iter().enumerate()
|
||||
{
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
name: (*name).into(),
|
||||
|
@ -225,7 +227,7 @@ impl TestEnvironment {
|
|||
|
||||
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(7),
|
||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
});
|
||||
top_level_defs.push(
|
||||
|
@ -233,7 +235,7 @@ impl TestEnvironment {
|
|||
name: "Foo".into(),
|
||||
object_id: DefinitionId(7),
|
||||
type_vars: vec![v0],
|
||||
fields: [("a".into(), v0)].into(),
|
||||
fields: [("a".into(), v0, true)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
|
@ -259,7 +261,7 @@ impl TestEnvironment {
|
|||
));
|
||||
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(8),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)]
|
||||
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
|
@ -271,7 +273,7 @@ impl TestEnvironment {
|
|||
name: "Bar".into(),
|
||||
object_id: DefinitionId(8),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)].into(),
|
||||
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
|
@ -288,7 +290,7 @@ impl TestEnvironment {
|
|||
|
||||
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(9),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)]
|
||||
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
|
@ -300,7 +302,7 @@ impl TestEnvironment {
|
|||
name: "Bar2".into(),
|
||||
object_id: DefinitionId(9),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)].into(),
|
||||
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
|
|
|
@ -49,7 +49,7 @@ pub struct FunSignature {
|
|||
pub enum TypeVarMeta {
|
||||
Generic,
|
||||
Sequence(RefCell<Mapping<i32>>),
|
||||
Record(RefCell<Mapping<StrRef>>),
|
||||
Record(RefCell<Mapping<StrRef, (Type, bool)>>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -71,7 +71,7 @@ pub enum TypeEnum {
|
|||
},
|
||||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: RefCell<Mapping<StrRef>>,
|
||||
fields: RefCell<Mapping<StrRef, (Type, bool)>>,
|
||||
params: RefCell<VarMap>,
|
||||
},
|
||||
TVirtual {
|
||||
|
@ -155,7 +155,7 @@ impl Unifier {
|
|||
self.unification_table.new_key(Rc::new(a))
|
||||
}
|
||||
|
||||
pub fn add_record(&mut self, fields: Mapping<StrRef>) -> Type {
|
||||
pub fn add_record(&mut self, fields: Mapping<StrRef, (Type, bool)>) -> Type {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
self.add_ty(TypeEnum::TVar {
|
||||
|
@ -394,11 +394,12 @@ impl Unifier {
|
|||
}
|
||||
(Record(fields1), Record(fields2)) => {
|
||||
let mut fields2 = fields2.borrow_mut();
|
||||
for (key, value) in fields1.borrow().iter() {
|
||||
if let Some(ty) = fields2.get(key) {
|
||||
self.unify_impl(*ty, *value, false)?;
|
||||
for (key, (ty, is_mutable)) in fields1.borrow().iter() {
|
||||
if let Some((ty2, is_mutable2)) = fields2.get_mut(key) {
|
||||
self.unify_impl(*ty2, *ty, false)?;
|
||||
*is_mutable2 |= *is_mutable;
|
||||
} else {
|
||||
fields2.insert(*key, *value);
|
||||
fields2.insert(*key, (*ty, *is_mutable));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -495,13 +496,19 @@ impl Unifier {
|
|||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields
|
||||
for (k, (ty, is_mutable)) in map.borrow().iter() {
|
||||
let (ty2, is_mutable2) = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
self.unify_impl(ty, *v, false)?;
|
||||
// typevar represents the usage of the variable
|
||||
// it is OK to have immutable usage for mutable fields
|
||||
// but cannot have mutable usage for immutable fields
|
||||
if *is_mutable && !is_mutable2 {
|
||||
return Err(format!("Field {} should be immutable", k));
|
||||
}
|
||||
self.unify_impl(*ty, ty2, false)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify_impl(x, b, false)?;
|
||||
|
@ -510,16 +517,19 @@ impl Unifier {
|
|||
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
||||
let ty = self.get_ty(*ty);
|
||||
if let TObj { fields, .. } = ty.as_ref() {
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields
|
||||
for (k, (ty, is_mutable)) in map.borrow().iter() {
|
||||
let (ty2, is_mutable2) = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
||||
if !matches!(self.get_ty(ty2).as_ref(), TFunc { .. }) {
|
||||
return Err(format!("Cannot access field {} for virtual type", k));
|
||||
}
|
||||
self.unify_impl(*v, ty, false)?;
|
||||
if *is_mutable && !is_mutable2 {
|
||||
return Err(format!("Field {} should be immutable", k));
|
||||
}
|
||||
self.unify_impl(*ty, ty2, false)?;
|
||||
}
|
||||
} else {
|
||||
// require annotation...
|
||||
|
@ -643,7 +653,9 @@ impl Unifier {
|
|||
let fields = fields
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)))
|
||||
.map(|(k, (v, _))| {
|
||||
format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))
|
||||
})
|
||||
.join(", ");
|
||||
format!("record[{}]", fields)
|
||||
}
|
||||
|
@ -805,7 +817,7 @@ impl Unifier {
|
|||
let params =
|
||||
self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
||||
let fields = self
|
||||
.subst_map(&fields.borrow(), mapping, cache)
|
||||
.subst_map2(&fields.borrow(), mapping, cache)
|
||||
.unwrap_or_else(|| fields.borrow().clone());
|
||||
let new_ty = self.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
|
@ -873,6 +885,27 @@ impl Unifier {
|
|||
map2
|
||||
}
|
||||
|
||||
fn subst_map2<K>(
|
||||
&mut self,
|
||||
map: &Mapping<K, (Type, bool)>,
|
||||
mapping: &VarMap,
|
||||
cache: &mut HashMap<Type, Option<Type>>,
|
||||
) -> Option<Mapping<K, (Type, bool)>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, (v, mutability)) in map.iter() {
|
||||
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||
if map2.is_none() {
|
||||
map2 = Some(map.clone());
|
||||
}
|
||||
*map2.as_mut().unwrap().get_mut(k).unwrap() = (v1, *mutability);
|
||||
}
|
||||
}
|
||||
map2
|
||||
}
|
||||
|
||||
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
||||
use TypeEnum::*;
|
||||
let x = self.get_ty(a);
|
||||
|
|
|
@ -39,7 +39,7 @@ impl Unifier {
|
|||
(
|
||||
TypeEnum::TVar { meta: Record(fields1), .. },
|
||||
TypeEnum::TVar { meta: Record(fields2), .. },
|
||||
) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
|
||||
) => self.map_eq2(&fields1.borrow(), &fields2.borrow()),
|
||||
(
|
||||
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||
|
@ -63,6 +63,25 @@ impl Unifier {
|
|||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn map_eq2<K>(
|
||||
&mut self,
|
||||
map1: &Mapping<K, (Type, bool)>,
|
||||
map2: &Mapping<K, (Type, bool)>,
|
||||
) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
if map1.len() != map2.len() {
|
||||
return false;
|
||||
}
|
||||
for (k, (ty1, m1)) in map1.iter() {
|
||||
if !map2.get(k).map(|(ty2, m2)| m1 == m2 && self.eq(*ty1, *ty2)).unwrap_or(false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
|
@ -104,7 +123,11 @@ impl TestEnvironment {
|
|||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
fields: [("a".into(), (v0, true))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
}),
|
||||
);
|
||||
|
@ -151,7 +174,7 @@ impl TestEnvironment {
|
|||
let eq = s.find('=').unwrap();
|
||||
let key = s[1..eq].into();
|
||||
let result = self.internal_parse(&s[eq + 1..], mapping);
|
||||
fields.insert(key, result.0);
|
||||
fields.insert(key, (result.0, true));
|
||||
s = result.1;
|
||||
}
|
||||
(self.unifier.add_record(fields), &s[1..])
|
||||
|
@ -326,7 +349,7 @@ fn test_recursive_subst() {
|
|||
let foo_ty = env.unifier.get_ty(foo_id);
|
||||
let mapping: HashMap<_, _>;
|
||||
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
|
||||
fields.borrow_mut().insert("rec".into(), foo_id);
|
||||
fields.borrow_mut().insert("rec".into(), (foo_id, true));
|
||||
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
|
||||
} else {
|
||||
unreachable!()
|
||||
|
@ -335,8 +358,8 @@ fn test_recursive_subst() {
|
|||
let instantiated_ty = env.unifier.get_ty(instantiated);
|
||||
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
||||
let fields = fields.borrow();
|
||||
assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int));
|
||||
assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated));
|
||||
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
|
||||
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
|
@ -351,7 +374,7 @@ fn test_virtual() {
|
|||
));
|
||||
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
fields: [("f".into(), fun), ("a".into(), int)]
|
||||
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<StrRef, _>>()
|
||||
|
@ -363,15 +386,15 @@ fn test_virtual() {
|
|||
|
||||
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
||||
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
||||
let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect());
|
||||
let c = env.unifier.add_record([("f".into(), (v1, false))].iter().cloned().collect());
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
env.unifier.unify(b, c).unwrap();
|
||||
assert!(env.unifier.eq(v1, fun));
|
||||
|
||||
let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect());
|
||||
let d = env.unifier.add_record([("a".into(), (v1, true))].iter().cloned().collect());
|
||||
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
|
||||
|
||||
let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect());
|
||||
let d = env.unifier.add_record([("b".into(), (v1, true))].iter().cloned().collect());
|
||||
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue