hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
6 changed files with 50 additions and 38 deletions
Showing only changes of commit cb01c79603 - Show all commits

View File

@ -49,7 +49,7 @@ pub struct CodeGenTask {
pub signature: FunSignature, pub signature: FunSignature,
pub body: Vec<Stmt<Option<Type>>>, pub body: Vec<Stmt<Option<Type>>>,
pub unifier_index: usize, pub unifier_index: usize,
pub resolver: Arc<dyn SymbolResolver>, pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
} }
fn get_llvm_type<'ctx>( fn get_llvm_type<'ctx>(
@ -108,7 +108,7 @@ pub fn gen_func<'ctx>(
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
) -> Module<'ctx> { ) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon // unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = { let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read(); let unifiers = top_level_ctx.unifiers.read();
@ -199,5 +199,7 @@ pub fn gen_func<'ctx>(
code_gen_context.gen_stmt(stmt); code_gen_context.gen_stmt(stmt);
} }
code_gen_context.module let CodeGenContext { builder, module, .. } = code_gen_context;
(builder, module)
} }

View File

@ -6,7 +6,7 @@ use crate::{
typecheck::{ typecheck::{
magic_methods::set_primitives_magic_methods, magic_methods::set_primitives_magic_methods,
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore}, type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
}, },
}; };
use indoc::indoc; use indoc::indoc;
@ -48,7 +48,7 @@ struct TestEnvironment {
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>, pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>, pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext, pub top_level: TopLevelContext,
} }
@ -102,7 +102,7 @@ impl TestEnvironment {
id_to_type: identifier_mapping.clone(), id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(), id_to_def: Default::default(),
class_names: Default::default(), class_names: Default::default(),
}) as Arc<dyn SymbolResolver>; }) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment { TestEnvironment {
unifier, unifier,
@ -239,7 +239,7 @@ fn test_primitives() {
} }
"} "}
.trim(); .trim();
let ir = module.print_to_string().to_string(); let ir = module.1.print_to_string().to_string();
println!("src:\n{}", source); println!("src:\n{}", source);
println!("IR:\n{}", ir); println!("IR:\n{}", ir);
assert_eq!(expected, ir.trim()); assert_eq!(expected, ir.trim());

View File

@ -161,19 +161,7 @@ pub fn parse_type_annotation<T>(
} }
} }
impl dyn SymbolResolver { impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
parse_type_annotation(self, top_level, unifier, primitives, expr)
}
}
impl dyn SymbolResolver + Send {
pub fn parse_type_annotation<T>( pub fn parse_type_annotation<T>(
&self, &self,
top_level: &TopLevelContext, top_level: &TopLevelContext,

View File

@ -3,8 +3,8 @@ use std::convert::{From, TryInto};
use std::iter::once; use std::iter::once;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::magic_methods::*;
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext}; use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
use itertools::izip; use itertools::izip;
use rustpython_parser::ast::{ use rustpython_parser::ast::{
@ -38,7 +38,7 @@ pub struct PrimitiveStore {
} }
pub struct FunctionData { pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver>, pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub return_type: Option<Type>, pub return_type: Option<Type>,
pub bound_variables: Vec<Type>, pub bound_variables: Vec<Type>,
} }
@ -50,7 +50,7 @@ pub struct Inferencer<'a> {
pub primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>, pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>, pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut HashMap<CodeLocation, Arc<Call>>, pub calls: &'a mut HashMap<CodeLocation, CallId>,
} }
struct NaiveFolder(); struct NaiveFolder();
@ -190,13 +190,13 @@ impl<'a> Inferencer<'a> {
params: Vec<Type>, params: Vec<Type>,
ret: Type, ret: Type,
) -> InferenceResult { ) -> InferenceResult {
let call = Arc::new(Call { let call = self.unifier.add_call(Call {
posargs: params, posargs: params,
kwargs: HashMap::new(), kwargs: HashMap::new(),
ret, ret,
fun: RefCell::new(None), fun: RefCell::new(None),
}); });
self.calls.insert(location.into(), call.clone()); self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect(); let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields); let record = self.unifier.add_record(fields);
@ -398,7 +398,7 @@ impl<'a> Inferencer<'a> {
.map(|v| fold::fold_keyword(self, v)) .map(|v| fold::fold_keyword(self, v))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let ret = self.unifier.get_fresh_var().0; let ret = self.unifier.get_fresh_var().0;
let call = Arc::new(Call { let call = self.unifier.add_call(Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(), posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords kwargs: keywords
.iter() .iter()
@ -407,7 +407,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None), fun: RefCell::new(None),
ret, ret,
}); });
self.calls.insert(location.into(), call.clone()); self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?; self.unifier.unify(func.custom.unwrap(), call)?;

View File

@ -40,7 +40,7 @@ struct TestEnvironment {
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>, pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>, pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext, pub top_level: TopLevelContext,
} }
@ -94,7 +94,7 @@ impl TestEnvironment {
id_to_type: identifier_mapping.clone(), id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(), id_to_def: Default::default(),
class_names: Default::default(), class_names: Default::default(),
}) as Arc<dyn SymbolResolver>; }) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment { TestEnvironment {
top_level: TopLevelContext { top_level: TopLevelContext {
@ -273,7 +273,7 @@ impl TestEnvironment {
.cloned() .cloned()
.collect(), .collect(),
class_names, class_names,
}) as Arc<dyn SymbolResolver>; }) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment { TestEnvironment {
unifier, unifier,

View File

@ -16,6 +16,9 @@ mod test;
/// Handle for a type, implementated as a key in the unification table. /// Handle for a type, implementated as a key in the unification table.
pub type Type = UnificationKey; pub type Type = UnificationKey;
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CallId(usize);
pub type Mapping<K, V = Type> = HashMap<K, V>; pub type Mapping<K, V = Type> = HashMap<K, V>;
type VarMap = Mapping<u32>; type VarMap = Mapping<u32>;
@ -73,7 +76,7 @@ pub enum TypeEnum {
TVirtual { TVirtual {
ty: Type, ty: Type,
}, },
TCall(RefCell<Vec<Arc<Call>>>), TCall(RefCell<Vec<CallId>>),
TFunc(FunSignature), TFunc(FunSignature),
} }
@ -92,17 +95,18 @@ impl TypeEnum {
} }
} }
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>; pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
pub struct Unifier { pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>, unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
var_id: u32, var_id: u32,
} }
impl Unifier { impl Unifier {
/// Get an empty unifier /// Get an empty unifier
pub fn new() -> Unifier { pub fn new() -> Unifier {
Unifier { unification_table: UnificationTable::new(), var_id: 0 } Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
} }
/// Determine if the two types are the same /// Determine if the two types are the same
@ -112,11 +116,19 @@ impl Unifier {
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap(); let lock = unifier.lock().unwrap();
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 } Unifier {
unification_table: UnificationTable::from_send(&lock.0),
var_id: lock.1,
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
}
} }
pub fn get_shared_unifier(&self) -> SharedUnifier { pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id))) Arc::new(Mutex::new((
self.unification_table.get_send(),
self.var_id,
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
)))
} }
/// Register a type to the unifier. /// Register a type to the unifier.
@ -135,6 +147,12 @@ impl Unifier {
}) })
} }
pub fn add_call(&mut self, call: Call) -> CallId {
let id = CallId(self.calls.len());
self.calls.push(Rc::new(call));
id
}
pub fn get_representative(&mut self, ty: Type) -> Type { pub fn get_representative(&mut self, ty: Type) -> Type {
self.unification_table.get_representative(ty) self.unification_table.get_representative(ty)
} }
@ -463,11 +481,11 @@ impl Unifier {
.collect(); .collect();
// we unify every calls to the function signature. // we unify every calls to the function signature.
for c in calls.borrow().iter() { for c in calls.borrow().iter() {
let Call { posargs, kwargs, ret, fun } = c.as_ref(); let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
let instantiated = self.instantiate_fun(b, signature); let instantiated = self.instantiate_fun(b, signature);
let signature;
let r = self.get_ty(instantiated); let r = self.get_ty(instantiated);
let r = r.as_ref(); let r = r.as_ref();
let signature;
if let TypeEnum::TFunc(s) = &*r { if let TypeEnum::TFunc(s) = &*r {
signature = s; signature = s;
} else { } else {
@ -765,10 +783,14 @@ impl Unifier {
} }
} }
TypeEnum::TCall(calls) => { TypeEnum::TCall(calls) => {
let call_store = self.calls.clone();
for t in calls for t in calls
.borrow() .borrow()
.iter() .iter()
.map(|call| chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))) .map(|call| {
let call = call_store[call.0].as_ref();
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
})
.flatten() .flatten()
{ {
self.occur_check(a, *t)?; self.occur_check(a, *t)?;