forked from M-Labs/nac3
1
0
Fork 0

add RefCell to FunSignature in TypeEnum

This commit is contained in:
ychenfo 2021-08-16 13:49:10 +08:00
parent d8c3c063ec
commit 3734663188
6 changed files with 59 additions and 40 deletions

View File

@ -1,6 +1,5 @@
use std::borrow::{Borrow, BorrowMut}; use std::borrow::BorrowMut;
use std::collections::HashSet; use std::{collections::HashMap, collections::HashSet, sync::Arc};
use std::{collections::HashMap, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
@ -52,6 +51,16 @@ pub enum TopLevelDef {
}, },
} }
impl TopLevelDef {
fn get_function_type(&self) -> Result<Type, String> {
if let Self::Function { signature, .. } = self {
Ok(*signature)
} else {
Err("only expect function def here".into())
}
}
}
pub struct TopLevelContext { pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>, pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>, pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
@ -219,18 +228,15 @@ impl TopLevelComposer {
let fun_name = Self::name_mangling(class_name.clone(), name); let fun_name = Self::name_mangling(class_name.clone(), name);
let def_id = def_list.len(); let def_id = def_list.len();
// add to unifier
let ty = self.unifier.write().add_ty(TypeEnum::TFunc(FunSignature {
args: Default::default(),
ret: self.primitives.none,
vars: Default::default(),
}));
// add to the definition list // add to the definition list
def_list.push( def_list.push(
Self::make_top_level_function_def( Self::make_top_level_function_def(
fun_name.clone(), fun_name.clone(),
ty, self.unifier.write().add_ty(TypeEnum::TFunc(FunSignature {
args: Default::default(),
ret: self.primitives.none.into(),
vars: Default::default(),
}.into())),
resolver.clone(), resolver.clone(),
) )
.into(), .into(),
@ -309,7 +315,7 @@ impl TopLevelComposer {
} else { continue } } else { continue }
}; };
let mut generic_occured = false; let mut is_generic = false;
for b in class_bases { for b in class_bases {
match &b.node { match &b.node {
// analyze typevars bounded to the class, // analyze typevars bounded to the class,
@ -321,8 +327,8 @@ impl TopLevelComposer {
// can only be `Generic[...]` and this can only appear once // can only be `Generic[...]` and this can only appear once
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "Generic" { if id == "Generic" {
if !generic_occured { if !is_generic {
generic_occured = true; is_generic = true;
true true
} else { } else {
return Err("Only single Generic[...] can be in bases".into()) return Err("Only single Generic[...] can be in bases".into())
@ -467,10 +473,10 @@ impl TopLevelComposer {
while !to_be_analyzed_class.is_empty() { while !to_be_analyzed_class.is_empty() {
let ind = to_be_analyzed_class.remove(0).0; let ind = to_be_analyzed_class.remove(0).0;
let (class_def, class_ast) = ( let (class_def, class_ast) = (
&mut def_list[ind], &ast_list[ind] &mut def_list[ind], &ast_list[ind]
); );
let ( let (
class_name, class_name,
class_fields, class_fields,
@ -491,7 +497,10 @@ impl TopLevelComposer {
}, .. }) = class_ast { }, .. }) = class_ast {
(name, fields, methods, resolver, body) (name, fields, methods, resolver, body)
} else { unreachable!("must be both class") } } else { unreachable!("must be both class") }
} else { continue } } else {
to_be_analyzed_class.push(DefinitionId(ind));
continue
}
}; };
for b in class_body { for b in class_body {
if let ast::StmtKind::FunctionDef { if let ast::StmtKind::FunctionDef {
@ -508,13 +517,19 @@ impl TopLevelComposer {
class_name.into(), class_name.into(),
func_name) func_name)
).unwrap(); ).unwrap();
let method_def = def_list[method_def_id.0].write();
let a = &def_list[method_def_id.0]; let method_ty = method_def.get_function_type()?;
let method_signature = unifier.get_ty(method_ty);
if let TypeEnum::TFunc(sig) = method_signature.as_ref() {
let mut sig = &mut *sig.borrow_mut();
} else { unreachable!() }
} else { } else {
// what should we do with `class A: a = 3`? // what should we do with `class A: a = 3`?
continue continue
} }
} }
} }
Ok(()) Ok(())

View File

@ -84,7 +84,7 @@ pub fn impl_binop(
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { ty: other, default_value: None, name: "other".into() }], args: vec![FuncArg { ty: other, default_value: None, name: "other".into() }],
})) }.into()))
}); });
fields.borrow_mut().insert(binop_assign_name(op).into(), { fields.borrow_mut().insert(binop_assign_name(op).into(), {
@ -97,7 +97,7 @@ pub fn impl_binop(
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { ty: other, default_value: None, name: "other".into() }], args: vec![FuncArg { ty: other, default_value: None, name: "other".into() }],
})) }.into()))
}); });
} }
} else { } else {
@ -120,7 +120,7 @@ pub fn impl_unaryop(
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![], args: vec![],
})), }.into())),
); );
} }
} else { } else {
@ -143,7 +143,7 @@ pub fn impl_cmpop(
ret: store.bool, ret: store.bool,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { ty: other_ty, default_value: None, name: "other".into() }], args: vec![FuncArg { ty: other_ty, default_value: None, name: "other".into() }],
})), }.into())),
); );
} }
} else { } else {

View File

@ -258,7 +258,7 @@ impl<'a> Inferencer<'a> {
Ok(Located { Ok(Located {
location, location,
node: ExprKind::Lambda { args: args.into(), body: body.into() }, node: ExprKind::Lambda { args: args.into(), body: body.into() },
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))), custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))),
}) })
} }

View File

@ -180,14 +180,14 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(), vars: [(id, v0)].iter().cloned().collect(),
})), }.into())),
); );
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature { let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: int32, ret: int32,
vars: Default::default(), vars: Default::default(),
})); }.into()));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: DefinitionId(6),
fields: [("a".into(), int32), ("b".into(), fun)] fields: [("a".into(), int32), ("b".into(), fun)]
@ -211,7 +211,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: bar, ret: bar,
vars: Default::default(), vars: Default::default(),
})), }.into())),
); );
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
@ -237,7 +237,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: bar2, ret: bar2,
vars: Default::default(), vars: Default::default(),
})), }.into())),
); );
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();

View File

@ -1,5 +1,5 @@
use itertools::{chain, zip, Itertools}; use itertools::{chain, zip, Itertools};
use std::borrow::Cow; use std::borrow::{Borrow, Cow};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::iter::once; use std::iter::once;
@ -77,7 +77,7 @@ pub enum TypeEnum {
ty: Type, ty: Type,
}, },
TCall(RefCell<Vec<CallId>>), TCall(RefCell<Vec<CallId>>),
TFunc(FunSignature), TFunc(RefCell<FunSignature>),
} }
impl TypeEnum { impl TypeEnum {
@ -472,7 +472,7 @@ impl Unifier {
} }
(TCall(calls), TFunc(signature)) => { (TCall(calls), TFunc(signature)) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
let required: Vec<String> = signature let required: Vec<String> = signature.borrow()
.args .args
.iter() .iter()
.filter(|v| v.default_value.is_none()) .filter(|v| v.default_value.is_none())
@ -482,7 +482,7 @@ impl Unifier {
// we unify every calls to the function signature. // we unify every calls to the function signature.
for c in calls.borrow().iter() { for c in calls.borrow().iter() {
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone(); let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
let instantiated = self.instantiate_fun(b, signature); let instantiated = self.instantiate_fun(b, &*signature.borrow());
let r = self.get_ty(instantiated); let r = self.get_ty(instantiated);
let r = r.as_ref(); let r = r.as_ref();
let signature; let signature;
@ -495,9 +495,9 @@ impl Unifier {
// arguments) are provided, and do not provide the same argument twice. // arguments) are provided, and do not provide the same argument twice.
let mut required = required.clone(); let mut required = required.clone();
let mut all_names: Vec<_> = let mut all_names: Vec<_> =
signature.args.iter().map(|v| (v.name.clone(), v.ty)).rev().collect(); signature.borrow().args.iter().map(|v| (v.name.clone(), v.ty)).rev().collect();
for (i, t) in posargs.iter().enumerate() { for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i { if signature.borrow().args.len() <= i {
return Err("Too many arguments.".to_string()); return Err("Too many arguments.".to_string());
} }
if !required.is_empty() { if !required.is_empty() {
@ -518,12 +518,13 @@ impl Unifier {
if !required.is_empty() { if !required.is_empty() {
return Err("Expected more arguments".to_string()); return Err("Expected more arguments".to_string());
} }
self.unify(*ret, signature.ret)?; self.unify(*ret, signature.borrow().ret)?;
*fun.borrow_mut() = Some(instantiated); *fun.borrow_mut() = Some(instantiated);
} }
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TFunc(sign1), TFunc(sign2)) => { (TFunc(sign1), TFunc(sign2)) => {
let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow());
if !sign1.vars.is_empty() || !sign2.vars.is_empty() { if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
return Err("Polymorphic function pointer is prohibited.".to_string()); return Err("Polymorphic function pointer is prohibited.".to_string());
} }
@ -604,13 +605,14 @@ impl Unifier {
TypeEnum::TCall { .. } => "call".to_owned(), TypeEnum::TCall { .. } => "call".to_owned(),
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
let params = signature let params = signature
.borrow()
.args .args
.iter() .iter()
.map(|arg| { .map(|arg| {
format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name)) format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name))
}) })
.join(", "); .join(", ");
let ret = self.stringify(signature.ret, obj_to_name, var_to_name); let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name);
format!("fn[[{}], {}]", params, ret) format!("fn[[{}], {}]", params, ret)
} }
} }
@ -723,7 +725,8 @@ impl Unifier {
None None
} }
} }
TypeEnum::TFunc(FunSignature { args, ret, vars: params }) => { TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
let new_params = self.subst_map(params, mapping); let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping); let new_ret = self.subst(*ret, mapping);
let mut new_args = Cow::from(args); let mut new_args = Cow::from(args);
@ -738,7 +741,7 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone()); let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret); let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.into_owned(); let args = new_args.into_owned();
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params }))) Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params }.into())))
} else { } else {
None None
} }
@ -809,7 +812,8 @@ impl Unifier {
self.occur_check(a, *t)?; self.occur_check(a, *t)?;
} }
} }
TypeEnum::TFunc(FunSignature { args, ret, vars: params }) => { TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) { for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
self.occur_check(a, *t)?; self.occur_check(a, *t)?;
} }

View File

@ -333,7 +333,7 @@ fn test_virtual() {
args: vec![], args: vec![],
ret: int, ret: int,
vars: HashMap::new(), vars: HashMap::new(),
})); }.into()));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: [("f".to_string(), fun), ("a".to_string(), int)] fields: [("f".to_string(), fun), ("a".to_string(), int)]