hm-inference #6
|
@ -49,7 +49,7 @@ pub struct CodeGenTask {
|
|||
pub signature: FunSignature,
|
||||
pub body: Vec<Stmt<Option<Type>>>,
|
||||
pub unifier_index: usize,
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
}
|
||||
|
||||
fn get_llvm_type<'ctx>(
|
||||
|
@ -108,7 +108,7 @@ pub fn gen_func<'ctx>(
|
|||
module: Module<'ctx>,
|
||||
task: CodeGenTask,
|
||||
top_level_ctx: Arc<TopLevelContext>,
|
||||
) -> Module<'ctx> {
|
||||
) -> (Builder<'ctx>, Module<'ctx>) {
|
||||
// unwrap_or(0) is for unit tests without using rayon
|
||||
let (mut unifier, primitives) = {
|
||||
let unifiers = top_level_ctx.unifiers.read();
|
||||
|
@ -199,5 +199,7 @@ pub fn gen_func<'ctx>(
|
|||
code_gen_context.gen_stmt(stmt);
|
||||
}
|
||||
|
||||
code_gen_context.module
|
||||
let CodeGenContext { builder, module, .. } = code_gen_context;
|
||||
|
||||
(builder, module)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
typecheck::{
|
||||
magic_methods::set_primitives_magic_methods,
|
||||
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
||||
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use indoc::indoc;
|
||||
|
@ -48,7 +48,7 @@ struct TestEnvironment {
|
|||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ impl TestEnvironment {
|
|||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
|
@ -239,7 +239,7 @@ fn test_primitives() {
|
|||
}
|
||||
"}
|
||||
.trim();
|
||||
let ir = module.print_to_string().to_string();
|
||||
let ir = module.1.print_to_string().to_string();
|
||||
println!("src:\n{}", source);
|
||||
println!("IR:\n{}", ir);
|
||||
assert_eq!(expected, ir.trim());
|
||||
|
|
|
@ -161,19 +161,7 @@ pub fn parse_type_annotation<T>(
|
|||
}
|
||||
}
|
||||
|
||||
impl dyn SymbolResolver {
|
||||
pub fn parse_type_annotation<T>(
|
||||
&self,
|
||||
top_level: &TopLevelContext,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
parse_type_annotation(self, top_level, unifier, primitives, expr)
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn SymbolResolver + Send {
|
||||
impl dyn SymbolResolver + Send + Sync {
|
||||
pub fn parse_type_annotation<T>(
|
||||
&self,
|
||||
top_level: &TopLevelContext,
|
||||
|
|
|
@ -3,8 +3,8 @@ use std::convert::{From, TryInto};
|
|||
use std::iter::once;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use super::magic_methods::*;
|
||||
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||
use super::{magic_methods::*, typedef::CallId};
|
||||
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
||||
use itertools::izip;
|
||||
use rustpython_parser::ast::{
|
||||
|
@ -38,7 +38,7 @@ pub struct PrimitiveStore {
|
|||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||
pub return_type: Option<Type>,
|
||||
pub bound_variables: Vec<Type>,
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ pub struct Inferencer<'a> {
|
|||
pub primitives: &'a PrimitiveStore,
|
||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||
pub variable_mapping: HashMap<String, Type>,
|
||||
pub calls: &'a mut HashMap<CodeLocation, Arc<Call>>,
|
||||
pub calls: &'a mut HashMap<CodeLocation, CallId>,
|
||||
}
|
||||
|
||||
struct NaiveFolder();
|
||||
|
@ -190,13 +190,13 @@ impl<'a> Inferencer<'a> {
|
|||
params: Vec<Type>,
|
||||
ret: Type,
|
||||
) -> InferenceResult {
|
||||
let call = Arc::new(Call {
|
||||
let call = self.unifier.add_call(Call {
|
||||
posargs: params,
|
||||
kwargs: HashMap::new(),
|
||||
ret,
|
||||
fun: RefCell::new(None),
|
||||
});
|
||||
self.calls.insert(location.into(), call.clone());
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
let fields = once((method, call)).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
|
@ -398,7 +398,7 @@ impl<'a> Inferencer<'a> {
|
|||
.map(|v| fold::fold_keyword(self, v))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
let call = Arc::new(Call {
|
||||
let call = self.unifier.add_call(Call {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
|
@ -407,7 +407,7 @@ impl<'a> Inferencer<'a> {
|
|||
fun: RefCell::new(None),
|
||||
ret,
|
||||
});
|
||||
self.calls.insert(location.into(), call.clone());
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ struct TestEnvironment {
|
|||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ impl TestEnvironment {
|
|||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
top_level: TopLevelContext {
|
||||
|
@ -273,7 +273,7 @@ impl TestEnvironment {
|
|||
.cloned()
|
||||
.collect(),
|
||||
class_names,
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
|
|
|
@ -16,6 +16,9 @@ mod test;
|
|||
/// Handle for a type, implementated as a key in the unification table.
|
||||
pub type Type = UnificationKey;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub struct CallId(usize);
|
||||
|
||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||
type VarMap = Mapping<u32>;
|
||||
|
||||
|
@ -73,7 +76,7 @@ pub enum TypeEnum {
|
|||
TVirtual {
|
||||
ty: Type,
|
||||
},
|
||||
TCall(RefCell<Vec<Arc<Call>>>),
|
||||
TCall(RefCell<Vec<CallId>>),
|
||||
TFunc(FunSignature),
|
||||
}
|
||||
|
||||
|
@ -92,17 +95,18 @@ impl TypeEnum {
|
|||
}
|
||||
}
|
||||
|
||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>;
|
||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||
|
||||
pub struct Unifier {
|
||||
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||
calls: Vec<Rc<Call>>,
|
||||
var_id: u32,
|
||||
}
|
||||
|
||||
impl Unifier {
|
||||
/// Get an empty unifier
|
||||
pub fn new() -> Unifier {
|
||||
Unifier { unification_table: UnificationTable::new(), var_id: 0 }
|
||||
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
|
||||
}
|
||||
|
||||
/// Determine if the two types are the same
|
||||
|
@ -112,11 +116,19 @@ impl Unifier {
|
|||
|
||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||
let lock = unifier.lock().unwrap();
|
||||
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 }
|
||||
Unifier {
|
||||
unification_table: UnificationTable::from_send(&lock.0),
|
||||
var_id: lock.1,
|
||||
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id)))
|
||||
Arc::new(Mutex::new((
|
||||
self.unification_table.get_send(),
|
||||
self.var_id,
|
||||
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Register a type to the unifier.
|
||||
|
@ -135,6 +147,12 @@ impl Unifier {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn add_call(&mut self, call: Call) -> CallId {
|
||||
let id = CallId(self.calls.len());
|
||||
self.calls.push(Rc::new(call));
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_representative(&mut self, ty: Type) -> Type {
|
||||
self.unification_table.get_representative(ty)
|
||||
}
|
||||
|
@ -463,11 +481,11 @@ impl Unifier {
|
|||
.collect();
|
||||
// we unify every calls to the function signature.
|
||||
for c in calls.borrow().iter() {
|
||||
let Call { posargs, kwargs, ret, fun } = c.as_ref();
|
||||
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
|
||||
let instantiated = self.instantiate_fun(b, signature);
|
||||
let signature;
|
||||
let r = self.get_ty(instantiated);
|
||||
let r = r.as_ref();
|
||||
let signature;
|
||||
if let TypeEnum::TFunc(s) = &*r {
|
||||
signature = s;
|
||||
} else {
|
||||
|
@ -765,10 +783,14 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
TypeEnum::TCall(calls) => {
|
||||
let call_store = self.calls.clone();
|
||||
for t in calls
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|call| chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)))
|
||||
.map(|call| {
|
||||
let call = call_store[call.0].as_ref();
|
||||
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
|
||||
})
|
||||
.flatten()
|
||||
{
|
||||
self.occur_check(a, *t)?;
|
||||
|
|
Loading…
Reference in New Issue