hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
6 changed files with 50 additions and 38 deletions
Showing only changes of commit cb01c79603 - Show all commits

View File

@ -49,7 +49,7 @@ pub struct CodeGenTask {
pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>,
pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver>,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
}
fn get_llvm_type<'ctx>(
@ -108,7 +108,7 @@ pub fn gen_func<'ctx>(
module: Module<'ctx>,
task: CodeGenTask,
top_level_ctx: Arc<TopLevelContext>,
) -> Module<'ctx> {
) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read();
@ -199,5 +199,7 @@ pub fn gen_func<'ctx>(
code_gen_context.gen_stmt(stmt);
}
code_gen_context.module
let CodeGenContext { builder, module, .. } = code_gen_context;
(builder, module)
}

View File

@ -6,7 +6,7 @@ use crate::{
typecheck::{
magic_methods::set_primitives_magic_methods,
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier},
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indoc::indoc;
@ -48,7 +48,7 @@ struct TestEnvironment {
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
@ -102,7 +102,7 @@ impl TestEnvironment {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver>;
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
@ -239,7 +239,7 @@ fn test_primitives() {
}
"}
.trim();
let ir = module.print_to_string().to_string();
let ir = module.1.print_to_string().to_string();
println!("src:\n{}", source);
println!("IR:\n{}", ir);
assert_eq!(expected, ir.trim());

View File

@ -161,19 +161,7 @@ pub fn parse_type_annotation<T>(
}
}
impl dyn SymbolResolver {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
parse_type_annotation(self, top_level, unifier, primitives, expr)
}
}
impl dyn SymbolResolver + Send {
impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,

View File

@ -3,8 +3,8 @@ use std::convert::{From, TryInto};
use std::iter::once;
use std::{cell::RefCell, sync::Arc};
use super::magic_methods::*;
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
use itertools::izip;
use rustpython_parser::ast::{
@ -38,7 +38,7 @@ pub struct PrimitiveStore {
}
pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver>,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}
@ -50,7 +50,7 @@ pub struct Inferencer<'a> {
pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut HashMap<CodeLocation, Arc<Call>>,
pub calls: &'a mut HashMap<CodeLocation, CallId>,
}
struct NaiveFolder();
@ -190,13 +190,13 @@ impl<'a> Inferencer<'a> {
params: Vec<Type>,
ret: Type,
) -> InferenceResult {
let call = Arc::new(Call {
let call = self.unifier.add_call(Call {
posargs: params,
kwargs: HashMap::new(),
ret,
fun: RefCell::new(None),
});
self.calls.insert(location.into(), call.clone());
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields);
@ -398,7 +398,7 @@ impl<'a> Inferencer<'a> {
.map(|v| fold::fold_keyword(self, v))
.collect::<Result<Vec<_>, _>>()?;
let ret = self.unifier.get_fresh_var().0;
let call = Arc::new(Call {
let call = self.unifier.add_call(Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords
.iter()
@ -407,7 +407,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None),
ret,
});
self.calls.insert(location.into(), call.clone());
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?;

View File

@ -40,7 +40,7 @@ struct TestEnvironment {
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
@ -94,7 +94,7 @@ impl TestEnvironment {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver>;
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
top_level: TopLevelContext {
@ -273,7 +273,7 @@ impl TestEnvironment {
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver>;
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,

View File

@ -16,6 +16,9 @@ mod test;
/// Handle for a type, implementated as a key in the unification table.
pub type Type = UnificationKey;
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CallId(usize);
pub type Mapping<K, V = Type> = HashMap<K, V>;
type VarMap = Mapping<u32>;
@ -73,7 +76,7 @@ pub enum TypeEnum {
TVirtual {
ty: Type,
},
TCall(RefCell<Vec<Arc<Call>>>),
TCall(RefCell<Vec<CallId>>),
TFunc(FunSignature),
}
@ -92,17 +95,18 @@ impl TypeEnum {
}
}
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>;
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
var_id: u32,
}
impl Unifier {
/// Get an empty unifier
pub fn new() -> Unifier {
Unifier { unification_table: UnificationTable::new(), var_id: 0 }
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
}
/// Determine if the two types are the same
@ -112,11 +116,19 @@ impl Unifier {
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap();
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 }
Unifier {
unification_table: UnificationTable::from_send(&lock.0),
var_id: lock.1,
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
}
}
pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id)))
Arc::new(Mutex::new((
self.unification_table.get_send(),
self.var_id,
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
)))
}
/// Register a type to the unifier.
@ -135,6 +147,12 @@ impl Unifier {
})
}
pub fn add_call(&mut self, call: Call) -> CallId {
let id = CallId(self.calls.len());
self.calls.push(Rc::new(call));
id
}
pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty)
}
@ -463,11 +481,11 @@ impl Unifier {
.collect();
// we unify every calls to the function signature.
for c in calls.borrow().iter() {
let Call { posargs, kwargs, ret, fun } = c.as_ref();
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
let instantiated = self.instantiate_fun(b, signature);
let signature;
let r = self.get_ty(instantiated);
let r = r.as_ref();
let signature;
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else {
@ -765,10 +783,14 @@ impl Unifier {
}
}
TypeEnum::TCall(calls) => {
let call_store = self.calls.clone();
for t in calls
.borrow()
.iter()
.map(|call| chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)))
.map(|call| {
let call = call_store[call.0].as_ref();
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
})
.flatten()
{
self.occur_check(a, *t)?;