hm-inference #6
|
@ -49,7 +49,7 @@ pub struct CodeGenTask {
|
||||||
pub signature: FunSignature,
|
pub signature: FunSignature,
|
||||||
pub body: Vec<Stmt<Option<Type>>>,
|
pub body: Vec<Stmt<Option<Type>>>,
|
||||||
pub unifier_index: usize,
|
pub unifier_index: usize,
|
||||||
pub resolver: Arc<dyn SymbolResolver>,
|
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_llvm_type<'ctx>(
|
fn get_llvm_type<'ctx>(
|
||||||
|
@ -108,7 +108,7 @@ pub fn gen_func<'ctx>(
|
||||||
module: Module<'ctx>,
|
module: Module<'ctx>,
|
||||||
task: CodeGenTask,
|
task: CodeGenTask,
|
||||||
top_level_ctx: Arc<TopLevelContext>,
|
top_level_ctx: Arc<TopLevelContext>,
|
||||||
) -> Module<'ctx> {
|
) -> (Builder<'ctx>, Module<'ctx>) {
|
||||||
// unwrap_or(0) is for unit tests without using rayon
|
// unwrap_or(0) is for unit tests without using rayon
|
||||||
let (mut unifier, primitives) = {
|
let (mut unifier, primitives) = {
|
||||||
let unifiers = top_level_ctx.unifiers.read();
|
let unifiers = top_level_ctx.unifiers.read();
|
||||||
|
@ -199,5 +199,7 @@ pub fn gen_func<'ctx>(
|
||||||
code_gen_context.gen_stmt(stmt);
|
code_gen_context.gen_stmt(stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
code_gen_context.module
|
let CodeGenContext { builder, module, .. } = code_gen_context;
|
||||||
|
|
||||||
|
(builder, module)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::set_primitives_magic_methods,
|
magic_methods::set_primitives_magic_methods,
|
||||||
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
||||||
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
|
@ -48,7 +48,7 @@ struct TestEnvironment {
|
||||||
pub id_to_name: HashMap<usize, String>,
|
pub id_to_name: HashMap<usize, String>,
|
||||||
pub identifier_mapping: HashMap<String, Type>,
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
pub virtual_checks: Vec<(Type, Type)>,
|
pub virtual_checks: Vec<(Type, Type)>,
|
||||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
pub calls: HashMap<CodeLocation, CallId>,
|
||||||
pub top_level: TopLevelContext,
|
pub top_level: TopLevelContext,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ impl TestEnvironment {
|
||||||
id_to_type: identifier_mapping.clone(),
|
id_to_type: identifier_mapping.clone(),
|
||||||
id_to_def: Default::default(),
|
id_to_def: Default::default(),
|
||||||
class_names: Default::default(),
|
class_names: Default::default(),
|
||||||
}) as Arc<dyn SymbolResolver>;
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
TestEnvironment {
|
TestEnvironment {
|
||||||
unifier,
|
unifier,
|
||||||
|
@ -239,7 +239,7 @@ fn test_primitives() {
|
||||||
}
|
}
|
||||||
"}
|
"}
|
||||||
.trim();
|
.trim();
|
||||||
let ir = module.print_to_string().to_string();
|
let ir = module.1.print_to_string().to_string();
|
||||||
println!("src:\n{}", source);
|
println!("src:\n{}", source);
|
||||||
println!("IR:\n{}", ir);
|
println!("IR:\n{}", ir);
|
||||||
assert_eq!(expected, ir.trim());
|
assert_eq!(expected, ir.trim());
|
||||||
|
|
|
@ -161,19 +161,7 @@ pub fn parse_type_annotation<T>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl dyn SymbolResolver {
|
impl dyn SymbolResolver + Send + Sync {
|
||||||
pub fn parse_type_annotation<T>(
|
|
||||||
&self,
|
|
||||||
top_level: &TopLevelContext,
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
expr: &Expr<T>,
|
|
||||||
) -> Result<Type, String> {
|
|
||||||
parse_type_annotation(self, top_level, unifier, primitives, expr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl dyn SymbolResolver + Send {
|
|
||||||
pub fn parse_type_annotation<T>(
|
pub fn parse_type_annotation<T>(
|
||||||
&self,
|
&self,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
|
|
|
@ -3,8 +3,8 @@ use std::convert::{From, TryInto};
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::magic_methods::*;
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||||
|
use super::{magic_methods::*, typedef::CallId};
|
||||||
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use rustpython_parser::ast::{
|
use rustpython_parser::ast::{
|
||||||
|
@ -38,7 +38,7 @@ pub struct PrimitiveStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FunctionData {
|
pub struct FunctionData {
|
||||||
pub resolver: Arc<dyn SymbolResolver>,
|
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
|
||||||
pub return_type: Option<Type>,
|
pub return_type: Option<Type>,
|
||||||
pub bound_variables: Vec<Type>,
|
pub bound_variables: Vec<Type>,
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ pub struct Inferencer<'a> {
|
||||||
pub primitives: &'a PrimitiveStore,
|
pub primitives: &'a PrimitiveStore,
|
||||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||||
pub variable_mapping: HashMap<String, Type>,
|
pub variable_mapping: HashMap<String, Type>,
|
||||||
pub calls: &'a mut HashMap<CodeLocation, Arc<Call>>,
|
pub calls: &'a mut HashMap<CodeLocation, CallId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NaiveFolder();
|
struct NaiveFolder();
|
||||||
|
@ -190,13 +190,13 @@ impl<'a> Inferencer<'a> {
|
||||||
params: Vec<Type>,
|
params: Vec<Type>,
|
||||||
ret: Type,
|
ret: Type,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let call = Arc::new(Call {
|
let call = self.unifier.add_call(Call {
|
||||||
posargs: params,
|
posargs: params,
|
||||||
kwargs: HashMap::new(),
|
kwargs: HashMap::new(),
|
||||||
ret,
|
ret,
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
});
|
});
|
||||||
self.calls.insert(location.into(), call.clone());
|
self.calls.insert(location.into(), call);
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
let fields = once((method, call)).collect();
|
let fields = once((method, call)).collect();
|
||||||
let record = self.unifier.add_record(fields);
|
let record = self.unifier.add_record(fields);
|
||||||
|
@ -398,7 +398,7 @@ impl<'a> Inferencer<'a> {
|
||||||
.map(|v| fold::fold_keyword(self, v))
|
.map(|v| fold::fold_keyword(self, v))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let ret = self.unifier.get_fresh_var().0;
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
let call = Arc::new(Call {
|
let call = self.unifier.add_call(Call {
|
||||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||||
kwargs: keywords
|
kwargs: keywords
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -407,7 +407,7 @@ impl<'a> Inferencer<'a> {
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret,
|
ret,
|
||||||
});
|
});
|
||||||
self.calls.insert(location.into(), call.clone());
|
self.calls.insert(location.into(), call);
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ struct TestEnvironment {
|
||||||
pub id_to_name: HashMap<usize, String>,
|
pub id_to_name: HashMap<usize, String>,
|
||||||
pub identifier_mapping: HashMap<String, Type>,
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
pub virtual_checks: Vec<(Type, Type)>,
|
pub virtual_checks: Vec<(Type, Type)>,
|
||||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
pub calls: HashMap<CodeLocation, CallId>,
|
||||||
pub top_level: TopLevelContext,
|
pub top_level: TopLevelContext,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ impl TestEnvironment {
|
||||||
id_to_type: identifier_mapping.clone(),
|
id_to_type: identifier_mapping.clone(),
|
||||||
id_to_def: Default::default(),
|
id_to_def: Default::default(),
|
||||||
class_names: Default::default(),
|
class_names: Default::default(),
|
||||||
}) as Arc<dyn SymbolResolver>;
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
TestEnvironment {
|
TestEnvironment {
|
||||||
top_level: TopLevelContext {
|
top_level: TopLevelContext {
|
||||||
|
@ -273,7 +273,7 @@ impl TestEnvironment {
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect(),
|
.collect(),
|
||||||
class_names,
|
class_names,
|
||||||
}) as Arc<dyn SymbolResolver>;
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
TestEnvironment {
|
TestEnvironment {
|
||||||
unifier,
|
unifier,
|
||||||
|
|
|
@ -16,6 +16,9 @@ mod test;
|
||||||
/// Handle for a type, implementated as a key in the unification table.
|
/// Handle for a type, implementated as a key in the unification table.
|
||||||
pub type Type = UnificationKey;
|
pub type Type = UnificationKey;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct CallId(usize);
|
||||||
|
|
||||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||||
type VarMap = Mapping<u32>;
|
type VarMap = Mapping<u32>;
|
||||||
|
|
||||||
|
@ -73,7 +76,7 @@ pub enum TypeEnum {
|
||||||
TVirtual {
|
TVirtual {
|
||||||
ty: Type,
|
ty: Type,
|
||||||
},
|
},
|
||||||
TCall(RefCell<Vec<Arc<Call>>>),
|
TCall(RefCell<Vec<CallId>>),
|
||||||
TFunc(FunSignature),
|
TFunc(FunSignature),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,17 +95,18 @@ impl TypeEnum {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>;
|
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||||
|
|
||||||
pub struct Unifier {
|
pub struct Unifier {
|
||||||
unification_table: UnificationTable<Rc<TypeEnum>>,
|
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||||
|
calls: Vec<Rc<Call>>,
|
||||||
var_id: u32,
|
var_id: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Unifier {
|
impl Unifier {
|
||||||
/// Get an empty unifier
|
/// Get an empty unifier
|
||||||
pub fn new() -> Unifier {
|
pub fn new() -> Unifier {
|
||||||
Unifier { unification_table: UnificationTable::new(), var_id: 0 }
|
Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determine if the two types are the same
|
/// Determine if the two types are the same
|
||||||
|
@ -112,11 +116,19 @@ impl Unifier {
|
||||||
|
|
||||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||||
let lock = unifier.lock().unwrap();
|
let lock = unifier.lock().unwrap();
|
||||||
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 }
|
Unifier {
|
||||||
|
unification_table: UnificationTable::from_send(&lock.0),
|
||||||
|
var_id: lock.1,
|
||||||
|
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||||
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id)))
|
Arc::new(Mutex::new((
|
||||||
|
self.unification_table.get_send(),
|
||||||
|
self.var_id,
|
||||||
|
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a type to the unifier.
|
/// Register a type to the unifier.
|
||||||
|
@ -135,6 +147,12 @@ impl Unifier {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_call(&mut self, call: Call) -> CallId {
|
||||||
|
let id = CallId(self.calls.len());
|
||||||
|
self.calls.push(Rc::new(call));
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_representative(&mut self, ty: Type) -> Type {
|
pub fn get_representative(&mut self, ty: Type) -> Type {
|
||||||
self.unification_table.get_representative(ty)
|
self.unification_table.get_representative(ty)
|
||||||
}
|
}
|
||||||
|
@ -463,11 +481,11 @@ impl Unifier {
|
||||||
.collect();
|
.collect();
|
||||||
// we unify every calls to the function signature.
|
// we unify every calls to the function signature.
|
||||||
for c in calls.borrow().iter() {
|
for c in calls.borrow().iter() {
|
||||||
let Call { posargs, kwargs, ret, fun } = c.as_ref();
|
let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone();
|
||||||
let instantiated = self.instantiate_fun(b, signature);
|
let instantiated = self.instantiate_fun(b, signature);
|
||||||
let signature;
|
|
||||||
let r = self.get_ty(instantiated);
|
let r = self.get_ty(instantiated);
|
||||||
let r = r.as_ref();
|
let r = r.as_ref();
|
||||||
|
let signature;
|
||||||
if let TypeEnum::TFunc(s) = &*r {
|
if let TypeEnum::TFunc(s) = &*r {
|
||||||
signature = s;
|
signature = s;
|
||||||
} else {
|
} else {
|
||||||
|
@ -765,10 +783,14 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TCall(calls) => {
|
TypeEnum::TCall(calls) => {
|
||||||
|
let call_store = self.calls.clone();
|
||||||
for t in calls
|
for t in calls
|
||||||
.borrow()
|
.borrow()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|call| chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)))
|
.map(|call| {
|
||||||
|
let call = call_store[call.0].as_ref();
|
||||||
|
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
|
||||||
|
})
|
||||||
.flatten()
|
.flatten()
|
||||||
{
|
{
|
||||||
self.occur_check(a, *t)?;
|
self.occur_check(a, *t)?;
|
||||||
|
|
Loading…
Reference in New Issue