1
0
forked from M-Labs/nac3

applied rustfmt and clippy auto fix

This commit is contained in:
pca006132 2022-02-21 18:27:46 +08:00
parent d9cb506f6a
commit f97f93d92c
27 changed files with 2038 additions and 1767 deletions

View File

@ -130,7 +130,8 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// the LLVM Context. // the LLVM Context.
// The name is guaranteed to be unique as users cannot use this as variable // The name is guaranteed to be unique as users cannot use this as variable
// name. // name.
self.start = old_start.clone().map_or_else(|| { self.start = old_start.clone().map_or_else(
|| {
let start = format!("with-{}-start", self.name_counter).into(); let start = format!("with-{}-start", self.name_counter).into();
let start_expr = Located { let start_expr = Located {
// location does not matter at this point // location does not matter at this point
@ -141,7 +142,9 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let start = self.gen_store_target(ctx, &start_expr)?; let start = self.gen_store_target(ctx, &start_expr)?;
ctx.builder.build_store(start, now); ctx.builder.build_store(start, now);
Ok(Some(start_expr)) as Result<_, String> Ok(Some(start_expr)) as Result<_, String>
}, |v| Ok(Some(v)))?; },
|v| Ok(Some(v)),
)?;
let end = format!("with-{}-end", self.name_counter).into(); let end = format!("with-{}-end", self.name_counter).into();
let end_expr = Located { let end_expr = Located {
// location does not matter at this point // location does not matter at this point
@ -179,8 +182,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
} }
// inside a parallel block, should update the outer max now_mu // inside a parallel block, should update the outer max now_mu
if let Some(old_end) = &old_end { if let Some(old_end) = &old_end {
let outer_end_val = let outer_end_val = self
self.gen_expr(ctx, old_end)?.unwrap().to_basic_value_enum(ctx, self); .gen_expr(ctx, old_end)?
.unwrap()
.to_basic_value_enum(ctx, self);
let smax = let smax =
ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
let i64 = ctx.ctx.i64_type(); let i64 = ctx.ctx.i64_type();
@ -226,7 +231,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
} }
} }
fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec<u8>) -> Result<(), String> { fn gen_rpc_tag<'ctx, 'a>(
ctx: &mut CodeGenContext<'ctx, 'a>,
ty: Type,
buffer: &mut Vec<u8>,
) -> Result<(), String> {
use nac3core::typecheck::typedef::TypeEnum::*; use nac3core::typecheck::typedef::TypeEnum::*;
let int32 = ctx.primitives.int32; let int32 = ctx.primitives.int32;
@ -283,7 +292,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
if obj.is_some() { if obj.is_some() {
@ -433,7 +442,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
return Ok(None) return Ok(None);
} }
let prehead_bb = ctx.builder.get_insert_block().unwrap(); let prehead_bb = ctx.builder.get_insert_block().unwrap();

View File

@ -10,7 +10,7 @@ use inkwell::{
targets::*, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::typedef::{Unifier, TypeEnum}; use nac3core::typecheck::typedef::{TypeEnum, Unifier};
use nac3parser::{ use nac3parser::{
ast::{self, ExprKind, Stmt, StmtKind, StrRef}, ast::{self, ExprKind, Stmt, StmtKind, StrRef},
parser::{self, parse_program}, parser::{self, parse_program},
@ -21,8 +21,8 @@ use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use nac3core::{ use nac3core::{
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
codegen::irrt::load_irrt, codegen::irrt::load_irrt,
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
composer::{ComposerConfig, TopLevelComposer}, composer::{ComposerConfig, TopLevelComposer},
@ -96,10 +96,7 @@ impl Nac3 {
) -> PyResult<()> { ) -> PyResult<()> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let module: &PyAny = module.extract(py)?; let module: &PyAny = module.extract(py)?;
Ok(( Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
module.getattr("__name__")?.extract()?,
module.getattr("__file__")?.extract()?,
))
})?; })?;
let source = fs::read_to_string(&source_file).map_err(|e| { let source = fs::read_to_string(&source_file).map_err(|e| {
@ -111,10 +108,7 @@ impl Nac3 {
for mut stmt in parser_result.into_iter() { for mut stmt in parser_result.into_iter() {
let include = match stmt.node { let include = match stmt.node {
ast::StmtKind::ClassDef { ast::StmtKind::ClassDef {
ref decorator_list, ref decorator_list, ref mut body, ref mut bases, ..
ref mut body,
ref mut bases,
..
} => { } => {
let nac3_class = decorator_list.iter().any(|decorator| { let nac3_class = decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node { if let ast::ExprKind::Name { id, .. } = decorator.node {
@ -146,10 +140,7 @@ impl Nac3 {
.unwrap() .unwrap()
}); });
body.retain(|stmt| { body.retain(|stmt| {
if let ast::StmtKind::FunctionDef { if let ast::StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
ref decorator_list, ..
} = stmt.node
{
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node { if let ast::ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" id.to_string() == "kernel"
@ -165,22 +156,21 @@ impl Nac3 {
}); });
true true
} }
ast::StmtKind::FunctionDef { ast::StmtKind::FunctionDef { ref decorator_list, .. } => {
ref decorator_list, .. decorator_list.iter().any(|decorator| {
} => decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node { if let ast::ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string(); let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" || id == "rpc" id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else { } else {
false false
} }
}), })
}
_ => false, _ => false,
}; };
if include { if include {
self.top_levels self.top_levels.push((stmt, module_name.clone(), module.clone()));
.push((stmt, module_name.clone(), module.clone()));
} }
} }
Ok(()) Ok(())
@ -197,7 +187,7 @@ impl Nac3 {
let base_ty = let base_ty =
match resolver.get_symbol_type(unifier, top_level_defs, primitives, "base".into()) { match resolver.get_symbol_type(unifier, top_level_defs, primitives, "base".into()) {
Ok(ty) => ty, Ok(ty) => ty,
Err(e) => return Some(format!("type error inside object launching kernel: {}", e)) Err(e) => return Some(format!("type error inside object launching kernel: {}", e)),
}; };
let fun_ty = if method_name.is_empty() { let fun_ty = if method_name.is_empty() {
@ -205,12 +195,15 @@ impl Nac3 {
} else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) { } else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) {
match fields.get(&(*method_name).into()) { match fields.get(&(*method_name).into()) {
Some(t) => t.0, Some(t) => t.0,
None => return Some( None => {
format!("object launching kernel does not have method `{}`", method_name) return Some(format!(
) "object launching kernel does not have method `{}`",
method_name
))
}
} }
} else { } else {
return Some("cannot launch kernel by calling a non-callable".into()) return Some("cannot launch kernel by calling a non-callable".into());
}; };
if let TypeEnum::TFunc(FunSignature { args, .. }) = &*unifier.get_ty(fun_ty) { if let TypeEnum::TFunc(FunSignature { args, .. }) = &*unifier.get_ty(fun_ty) {
@ -219,35 +212,43 @@ impl Nac3 {
"launching kernel function with too many arguments (expect {}, found {})", "launching kernel function with too many arguments (expect {}, found {})",
args.len(), args.len(),
arg_names.len(), arg_names.len(),
)) ));
} }
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() { for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) { let in_name = match arg_names.get(i) {
Some(n) => n, Some(n) => n,
None if default_value.is_none() => return Some(format!( None if default_value.is_none() => {
"argument `{}` not provided when launching kernel function", name return Some(format!(
)), "argument `{}` not provided when launching kernel function",
name
))
}
_ => break, _ => break,
}; };
let in_ty = match resolver.get_symbol_type( let in_ty = match resolver.get_symbol_type(
unifier, unifier,
top_level_defs, top_level_defs,
primitives, primitives,
in_name.clone().into() in_name.clone().into(),
) { ) {
Ok(t) => t, Ok(t) => t,
Err(e) => return Some(format!( Err(e) => {
"type error ({}) at parameter #{} when calling kernel function", e, i return Some(format!(
"type error ({}) at parameter #{} when calling kernel function",
e, i
)) ))
}
}; };
if let Err(e) = unifier.unify(in_ty, *ty) { if let Err(e) = unifier.unify(in_ty, *ty) {
return Some(format!( return Some(format!(
"type error ({}) at parameter #{} when calling kernel function", e.to_display(unifier).to_string(), i "type error ({}) at parameter #{} when calling kernel function",
e.to_display(unifier).to_string(),
i
)); ));
} }
} }
} else { } else {
return Some("cannot launch kernel by calling a non-callable".into()) return Some("cannot launch kernel by calling a non-callable".into());
} }
None None
} }
@ -274,11 +275,7 @@ impl Nac3 {
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
FunSignature { FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() },
args: vec![],
ret: primitive.int64,
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx))) Ok(Some(time_fns.emit_now_mu(ctx)))
}))), }))),
@ -320,10 +317,7 @@ impl Nac3 {
]; ];
let (_, builtins_def, builtins_ty) = TopLevelComposer::new( let (_, builtins_def, builtins_ty) = TopLevelComposer::new(
builtins.clone(), builtins.clone(),
ComposerConfig { ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant",
},
); );
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
@ -355,46 +349,22 @@ impl Nac3 {
.extract() .extract()
.unwrap(), .unwrap(),
), ),
none: id_fn none: id_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap().extract().unwrap(),
.call1((builtins_mod.getattr("None").unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: id_fn typevar: id_fn
.call1((typing_mod.getattr("TypeVar").unwrap(),)) .call1((typing_mod.getattr("TypeVar").unwrap(),))
.unwrap() .unwrap()
.extract() .extract()
.unwrap(), .unwrap(),
int: id_fn int: id_fn.call1((builtins_mod.getattr("int").unwrap(),)).unwrap().extract().unwrap(),
.call1((builtins_mod.getattr("int").unwrap(),)) int32: id_fn.call1((numpy_mod.getattr("int32").unwrap(),)).unwrap().extract().unwrap(),
.unwrap() int64: id_fn.call1((numpy_mod.getattr("int64").unwrap(),)).unwrap().extract().unwrap(),
.extract() bool: id_fn.call1((builtins_mod.getattr("bool").unwrap(),)).unwrap().extract().unwrap(),
.unwrap(),
int32: id_fn
.call1((numpy_mod.getattr("int32").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int64: id_fn
.call1((numpy_mod.getattr("int64").unwrap(),))
.unwrap()
.extract()
.unwrap(),
bool: id_fn
.call1((builtins_mod.getattr("bool").unwrap(),))
.unwrap()
.extract()
.unwrap(),
float: id_fn float: id_fn
.call1((builtins_mod.getattr("float").unwrap(),)) .call1((builtins_mod.getattr("float").unwrap(),))
.unwrap() .unwrap()
.extract() .extract()
.unwrap(), .unwrap(),
list: id_fn list: id_fn.call1((builtins_mod.getattr("list").unwrap(),)).unwrap().extract().unwrap(),
.call1((builtins_mod.getattr("list").unwrap(),))
.unwrap()
.extract()
.unwrap(),
tuple: id_fn tuple: id_fn
.call1((builtins_mod.getattr("tuple").unwrap(),)) .call1((builtins_mod.getattr("tuple").unwrap(),))
.unwrap() .unwrap()
@ -408,11 +378,7 @@ impl Nac3 {
}; };
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
fs::write( fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
working_directory.path().join("kernel.ld"),
include_bytes!("kernel.ld"),
)
.unwrap();
Ok(Nac3 { Ok(Nac3 {
isa, isa,
@ -425,7 +391,7 @@ impl Nac3 {
top_levels: Default::default(), top_levels: Default::default(),
pyid_to_def: Default::default(), pyid_to_def: Default::default(),
working_directory, working_directory,
string_store: Default::default() string_store: Default::default(),
}) })
} }
@ -465,20 +431,17 @@ impl Nac3 {
embedding_map: &PyAny, embedding_map: &PyAny,
py: Python, py: Python,
) -> PyResult<()> { ) -> PyResult<()> {
let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig { let (mut composer, _, _) = TopLevelComposer::new(
kernel_ann: Some("Kernel"), self.builtins.clone(),
kernel_invariant_ann: "KernelInvariant" ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
}); );
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?; let typings = PyModule::import(py, "typing")?;
let id_fn = builtins.getattr("id")?; let id_fn = builtins.getattr("id")?;
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py);
.getattr("store_function")
.unwrap()
.to_object(py);
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py), id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py),
@ -486,7 +449,7 @@ impl Nac3 {
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
store_obj, store_obj,
store_str store_str,
}; };
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
@ -497,10 +460,8 @@ impl Nac3 {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
let helper = helper.clone(); let helper = helper.clone();
let (name_to_pyid, resolver) = module_to_resolver_cache let (name_to_pyid, resolver) =
.get(&module_id) module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
.cloned()
.unwrap_or_else(|| {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members: &PyDict = let members: &PyDict =
py_module.getattr("__dict__").unwrap().cast_as().unwrap(); py_module.getattr("__dict__").unwrap().cast_as().unwrap();
@ -535,7 +496,10 @@ impl Nac3 {
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.map_err(|e| { .map_err(|e| {
exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure\n----------\n{}", e)) exceptions::PyRuntimeError::new_err(format!(
"nac3 compilation failure\n----------\n{}",
e
))
})?; })?;
match &stmt.node { match &stmt.node {
@ -583,16 +547,10 @@ impl Nac3 {
let synthesized = if method_name.is_empty() { let synthesized = if method_name.is_empty() {
format!("def __modinit__():\n base({})", arg_names.join(", ")) format!("def __modinit__():\n base({})", arg_names.join(", "))
} else { } else {
format!( format!("def __modinit__():\n base.{}({})", method_name, arg_names.join(", "))
"def __modinit__():\n base.{}({})",
method_name,
arg_names.join(", ")
)
}; };
let mut synthesized = parse_program( let mut synthesized =
&synthesized, parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap();
"__nac3_synthesized_modinit__".to_string().into(),
).unwrap();
let resolver = Arc::new(Resolver(Arc::new(InnerResolver { let resolver = Arc::new(Resolver(Arc::new(InnerResolver {
id_to_type: self.builtins_ty.clone().into(), id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(), id_to_def: self.builtins_def.clone().into(),
@ -610,34 +568,24 @@ impl Nac3 {
string_store: self.string_store.clone(), string_store: self.string_store.clone(),
}))) as Arc<dyn SymbolResolver + Send + Sync>; }))) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level( .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into())
synthesized.pop().unwrap(),
Some(resolver.clone()),
"".into(),
)
.unwrap(); .unwrap();
let signature = FunSignature { let signature =
args: vec![], FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() };
ret: self.primitive.none,
vars: HashMap::new(),
};
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = store.from_signature( let signature =
&mut composer.unifier, store.from_signature(&mut composer.unifier, &self.primitive, &signature, &mut cache);
&self.primitive,
&signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) { if let Err(e) = composer.start_analysis(true) {
// report error of __modinit__ separately // report error of __modinit__ separately
if !e.contains("__nac3_synthesized_modinit__") { if !e.contains("__nac3_synthesized_modinit__") {
return Err(exceptions::PyRuntimeError::new_err( return Err(exceptions::PyRuntimeError::new_err(format!(
format!("nac3 compilation failure: \n----------\n{}", e) "nac3 compilation failure: \n----------\n{}",
)); e
)));
} else { } else {
let msg = Self::report_modinit( let msg = Self::report_modinit(
&arg_names, &arg_names,
@ -645,7 +593,7 @@ impl Nac3 {
resolver.clone(), resolver.clone(),
&composer.extract_def_list(), &composer.extract_def_list(),
&mut composer.unifier, &mut composer.unifier,
&self.primitive &self.primitive,
); );
return Err(exceptions::PyRuntimeError::new_err(msg.unwrap())); return Err(exceptions::PyRuntimeError::new_err(msg.unwrap()));
} }
@ -658,9 +606,7 @@ impl Nac3 {
for (class_data, id) in rpc_ids.iter() { for (class_data, id) in rpc_ids.iter() {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { TopLevelDef::Function { codegen_callback, .. } => {
codegen_callback, ..
} => {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen.clone());
} }
TopLevelDef::Class { methods, .. } => { TopLevelDef::Class { methods, .. } => {
@ -669,9 +615,8 @@ impl Nac3 {
if name != method_name { if name != method_name {
continue; continue;
} }
if let TopLevelDef::Function { if let TopLevelDef::Function { codegen_callback, .. } =
codegen_callback, .. &mut *defs[id.0].write()
} = &mut *defs[id.0].write()
{ {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen.clone());
store_fun store_fun
@ -693,11 +638,8 @@ impl Nac3 {
let instance = { let instance = {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write(); let mut definition = defs[def_id.0].write();
if let TopLevelDef::Function { if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
instance_to_stmt, &mut *definition
instance_to_symbol,
..
} = &mut *definition
{ {
instance_to_symbol.insert("".to_string(), "__modinit__".into()); instance_to_symbol.insert("".to_string(), "__modinit__".into());
instance_to_stmt[""].clone() instance_to_stmt[""].clone()
@ -733,13 +675,7 @@ impl Nac3 {
let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
.iter() .iter()
.map(|s| { .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
Box::new(ArtiqCodeGenerator::new(
s.to_string(),
size_t,
self.time_fns,
))
})
.collect(); .collect();
py.allow_threads(|| { py.allow_threads(|| {
@ -784,14 +720,10 @@ impl Nac3 {
TargetMachine::get_default_triple(), TargetMachine::get_default_triple(),
TargetMachine::get_host_cpu_features().to_string(), TargetMachine::get_host_cpu_features().to_string(),
), ),
Isa::RiscV32G => ( Isa::RiscV32G => {
TargetTriple::create("riscv32-unknown-linux"), (TargetTriple::create("riscv32-unknown-linux"), "+a,+m,+f,+d".to_string())
"+a,+m,+f,+d".to_string(), }
), Isa::RiscV32IMA => (TargetTriple::create("riscv32-unknown-linux"), "+a,+m".to_string()),
Isa::RiscV32IMA => (
TargetTriple::create("riscv32-unknown-linux"),
"+a,+m".to_string(),
),
Isa::CortexA9 => ( Isa::CortexA9 => (
TargetTriple::create("armv7-unknown-linux-gnueabihf"), TargetTriple::create("armv7-unknown-linux-gnueabihf"),
"+dsp,+fp16,+neon,+vfp3".to_string(), "+dsp,+fp16,+neon,+vfp3".to_string(),
@ -819,28 +751,18 @@ impl Nac3 {
"-x".to_string(), "-x".to_string(),
"-o".to_string(), "-o".to_string(),
filename.to_string(), filename.to_string(),
working_directory working_directory.join("module.o").to_string_lossy().to_string(),
.join("module.o")
.to_string_lossy()
.to_string(),
]; ];
if isa != Isa::Host { if isa != Isa::Host {
linker_args.push( linker_args.push(
"-T".to_string() "-T".to_string()
+ self + self.working_directory.path().join("kernel.ld").to_str().unwrap(),
.working_directory
.path()
.join("kernel.ld")
.to_str()
.unwrap(),
); );
} }
if let Ok(linker_status) = Command::new("ld.lld").args(linker_args).status() { if let Ok(linker_status) = Command::new("ld.lld").args(linker_args).status() {
if !linker_status.success() { if !linker_status.success() {
return Err(exceptions::PyRuntimeError::new_err( return Err(exceptions::PyRuntimeError::new_err("failed to start linker"));
"failed to start linker",
));
} }
} else { } else {
return Err(exceptions::PyRuntimeError::new_err( return Err(exceptions::PyRuntimeError::new_err(

View File

@ -83,17 +83,21 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> { Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let global = let global = ctx.module.add_global(
ctx.module struct_type,
.add_global(struct_type, None, format!("{}_const", self.id).as_str()); None,
format!("{}_const", self.id).as_str(),
);
global.set_constant(true); global.set_constant(true);
global.set_initializer(&ctx.ctx.const_struct( global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(id as u64, false).into()], &[ctx.ctx.i32_type().const_int(id as u64, false).into()],
false, false,
)); ));
let global2 = let global2 = ctx.module.add_global(
ctx.module struct_type.ptr_type(AddressSpace::Generic),
.add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str()); None,
format!("{}_const2", self.id).as_str(),
);
global2.set_initializer(&global.as_pointer_value()); global2.set_initializer(&global.as_pointer_value());
Ok(global2.as_pointer_value().into()) Ok(global2.as_pointer_value().into())
}) })
@ -160,10 +164,7 @@ impl StaticValue for PythonValue {
let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?; let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?;
Some((id, obj)) Some((id, obj))
}; };
self.resolver self.resolver.field_to_val.write().insert((self.id, name), result.clone());
.field_to_val
.write()
.insert((self.id, name), result.clone());
Ok(result) Ok(result)
}) })
.unwrap() .unwrap()
@ -191,24 +192,27 @@ impl InnerResolver {
) -> PyResult<Result<Type, String>> { ) -> PyResult<Result<Type, String>> {
let mut ty = match self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)? { let mut ty = match self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)? {
Ok(t) => t, Ok(t) => t,
Err(e) => return Ok(Err(format!( Err(e) => return Ok(Err(format!("type error ({}) at element #0 of the list", e))),
"type error ({}) at element #0 of the list", e
))),
}; };
for i in 1..len { for i in 1..len {
let b = match list let b = match list
.get_item(i) .get_item(i)
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))?? { .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))??
{
Ok(t) => t, Ok(t) => t,
Err(e) => return Ok(Err(format!( Err(e) => {
"type error ({}) at element #{} of the list", e, i return Ok(Err(format!("type error ({}) at element #{} of the list", e, i)))
))), }
}; };
ty = match unifier.unify(ty, b) { ty = match unifier.unify(ty, b) {
Ok(_) => ty, Ok(_) => ty,
Err(e) => return Ok(Err(format!( Err(e) => {
"inhomogeneous type ({}) at element #{} of the list", e.to_display(unifier).to_string(), i return Ok(Err(format!(
"inhomogeneous type ({}) at element #{} of the list",
e.to_display(unifier).to_string(),
i
))) )))
}
}; };
} }
Ok(Ok(ty)) Ok(Ok(ty))
@ -227,11 +231,8 @@ impl InnerResolver {
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> PyResult<Result<(Type, bool), String>> { ) -> PyResult<Result<(Type, bool), String>> {
let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?;
let ty_ty_id: u64 = self let ty_ty_id: u64 =
.helper self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?;
.id_fn
.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?
.extract(py)?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
Ok(Ok((primitives.int32, true))) Ok(Ok((primitives.int32, true)))
@ -243,7 +244,7 @@ impl InnerResolver {
Ok(Ok((primitives.float, true))) Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.exception { } else if ty_id == self.primitive_ids.exception {
Ok(Ok((primitives.exception, true))) Ok(Ok((primitives.exception, true)))
}else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_dummy_var().0; let var = unifier.get_dummy_var().0;
let list = unifier.add_ty(TypeEnum::TList { ty: var }); let list = unifier.add_ty(TypeEnum::TList { ty: var });
@ -253,14 +254,7 @@ impl InnerResolver {
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
let def = defs[def_id.0].read(); let def = defs[def_id.0].read();
if let TopLevelDef::Class { if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
object_id,
type_vars,
fields,
methods,
..
} = &*def
{
// do not handle type var param and concrete check here, and no subst // do not handle type var param and concrete check here, and no subst
Ok(Ok({ Ok(Ok({
let ty = TypeEnum::TObj { let ty = TypeEnum::TObj {
@ -320,7 +314,8 @@ impl InnerResolver {
} }
result result
}; };
let res = unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; let res =
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0;
Ok(Ok((res, true))) Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 } else if ty_ty_id == self.primitive_ids.generic_alias.0
|| ty_ty_id == self.primitive_ids.generic_alias.1 || ty_ty_id == self.primitive_ids.generic_alias.1
@ -352,7 +347,7 @@ impl InnerResolver {
}; };
if !unifier.is_concrete(ty.0, &[]) && !ty.1 { if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
return Ok(Err( return Ok(Err(
"type list should take concrete parameters in typevar range".into() "type list should take concrete parameters in typevar range".into(),
)); ));
} }
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
@ -417,10 +412,7 @@ impl InnerResolver {
.map(|((id, _), ty)| (*id, *ty)) .map(|((id, _), ty)| (*id, *ty))
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
}; };
Ok(Ok(( Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true)))
unifier.subst(origin_ty, &subst).unwrap_or(origin_ty),
true,
)))
} }
TypeEnum::TVirtual { .. } => { TypeEnum::TVirtual { .. } => {
if args.len() == 1 { if args.len() == 1 {
@ -452,17 +444,19 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.virtual_id { } else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(( Ok(Ok((
{ {
let ty = TypeEnum::TVirtual { let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().0 };
ty: unifier.get_dummy_var().0,
};
unifier.add_ty(ty) unifier.add_ty(ty)
}, },
false, false,
))) )))
} else { } else {
let str_fn = pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!("{} is not supported in nac3 (did you forgot to put @nac3 annotation?)", str_repr))) Ok(Err(format!(
"{} is not supported in nac3 (did you forgot to put @nac3 annotation?)",
str_repr
)))
} }
} }
@ -483,13 +477,8 @@ impl InnerResolver {
self.primitive_ids.generic_alias.0, self.primitive_ids.generic_alias.0,
self.primitive_ids.generic_alias.1, self.primitive_ids.generic_alias.1,
] ]
.contains( .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?)
&self {
.helper
.id_fn
.call1(py, (ty.clone(),))?
.extract::<u64>(py)?,
) {
obj obj
} else { } else {
ty.as_ref(py) ty.as_ref(py)
@ -518,9 +507,12 @@ impl InnerResolver {
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty { match actual_ty {
Ok(t) => match unifier.unify(*ty, t) { Ok(t) => match unifier.unify(*ty, t) {
Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList{ ty: *ty }))), Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList { ty: *ty }))),
Err(e) => Ok(Err(format!("type error ({}) for the list", e.to_display(unifier).to_string()))), Err(e) => Ok(Err(format!(
} "type error ({}) for the list",
e.to_display(unifier).to_string()
))),
},
Err(e) => Ok(Err(e)), Err(e) => Ok(Err(e)),
} }
} }
@ -553,18 +545,23 @@ impl InnerResolver {
continue; continue;
} else { } else {
let field_data = obj.getattr(&name)?; let field_data = obj.getattr(&name)?;
let ty = match self let ty =
.get_obj_type(py, field_data, unifier, defs, primitives)? { match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t, Ok(t) => t,
Err(e) => return Ok(Err(format!( Err(e) => {
"error when getting type of field `{}` ({})", name, e return Ok(Err(format!(
))), "error when getting type of field `{}` ({})",
name, e
)))
}
}; };
let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0);
if let Err(e) = unifier.unify(ty, field_ty) { if let Err(e) = unifier.unify(ty, field_ty) {
// field type mismatch // field type mismatch
return Ok(Err(format!( return Ok(Err(format!(
"error when getting type of field `{}` ({})", name, e.to_display(unifier).to_string() "error when getting type of field `{}` ({})",
name,
e.to_display(unifier).to_string()
))); )));
} }
} }
@ -575,11 +572,7 @@ impl InnerResolver {
return Ok(Err("object is not of concrete type".into())); return Ok(Err("object is not of concrete type".into()));
} }
} }
return Ok(Ok( return Ok(Ok(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty)));
unifier
.subst(extracted_ty, &var_map)
.unwrap_or(extracted_ty),
));
} }
_ => Ok(Ok(extracted_ty)), _ => Ok(Ok(extracted_ty)),
}; };
@ -592,37 +585,24 @@ impl InnerResolver {
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> PyResult<Option<BasicValueEnum<'ctx>>> { ) -> PyResult<Option<BasicValueEnum<'ctx>>> {
let ty_id: u64 = self let ty_id: u64 =
.helper self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?;
.id_fn
.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?
.extract(py)?;
let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
let val: i32 = obj.extract()?; let val: i32 = obj.extract()?;
self.id_to_primitive self.id_to_primitive.write().insert(id, PrimitiveValue::I32(val));
.write()
.insert(id, PrimitiveValue::I32(val));
Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into()))
} else if ty_id == self.primitive_ids.int64 { } else if ty_id == self.primitive_ids.int64 {
let val: i64 = obj.extract()?; let val: i64 = obj.extract()?;
self.id_to_primitive self.id_to_primitive.write().insert(id, PrimitiveValue::I64(val));
.write()
.insert(id, PrimitiveValue::I64(val));
Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into()))
} else if ty_id == self.primitive_ids.bool { } else if ty_id == self.primitive_ids.bool {
let val: bool = obj.extract()?; let val: bool = obj.extract()?;
self.id_to_primitive self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
.write() Ok(Some(ctx.ctx.bool_type().const_int(val as u64, false).into()))
.insert(id, PrimitiveValue::Bool(val));
Ok(Some(
ctx.ctx.bool_type().const_int(val as u64, false).into(),
))
} else if ty_id == self.primitive_ids.float { } else if ty_id == self.primitive_ids.float {
let val: f64 = obj.extract()?; let val: f64 = obj.extract()?;
self.id_to_primitive self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
.write()
.insert(id, PrimitiveValue::F64(val));
Ok(Some(ctx.ctx.f64_type().const_float(val).into())) Ok(Some(ctx.ctx.f64_type().const_float(val).into()))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
let id_str = id.to_string(); let id_str = id.to_string();
@ -647,16 +627,14 @@ impl InnerResolver {
}; };
let ty = ctx.get_llvm_type(generator, ty); let ty = ctx.get_llvm_type(generator, ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let arr_ty = ctx.ctx.struct_type( let arr_ty = ctx
&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], .ctx
false, .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false);
);
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str)
.add_global(arr_ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
@ -666,8 +644,7 @@ impl InnerResolver {
let arr: Result<Option<Vec<_>>, _> = (0..len) let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| { .map(|i| {
obj.get_item(i) obj.get_item(i).and_then(|elem| self.get_obj_value(py, elem, ctx, generator))
.and_then(|elem| self.get_obj_value(py, elem, ctx, generator))
}) })
.collect(); .collect();
let arr = arr?.unwrap(); let arr = arr?.unwrap();
@ -678,34 +655,19 @@ impl InnerResolver {
&(id_str.clone() + "_"), &(id_str.clone() + "_"),
); );
let arr: BasicValueEnum = if ty.is_int_type() { let arr: BasicValueEnum = if ty.is_int_type() {
let arr: Vec<_> = arr let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_int_value).collect();
.into_iter()
.map(BasicValueEnum::into_int_value)
.collect();
ty.into_int_type().const_array(&arr) ty.into_int_type().const_array(&arr)
} else if ty.is_float_type() { } else if ty.is_float_type() {
let arr: Vec<_> = arr let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_float_value).collect();
.into_iter()
.map(BasicValueEnum::into_float_value)
.collect();
ty.into_float_type().const_array(&arr) ty.into_float_type().const_array(&arr)
} else if ty.is_array_type() { } else if ty.is_array_type() {
let arr: Vec<_> = arr let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_array_value).collect();
.into_iter()
.map(BasicValueEnum::into_array_value)
.collect();
ty.into_array_type().const_array(&arr) ty.into_array_type().const_array(&arr)
} else if ty.is_struct_type() { } else if ty.is_struct_type() {
let arr: Vec<_> = arr let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_struct_value).collect();
.into_iter()
.map(BasicValueEnum::into_struct_value)
.collect();
ty.into_struct_type().const_array(&arr) ty.into_struct_type().const_array(&arr)
} else if ty.is_pointer_type() { } else if ty.is_pointer_type() {
let arr: Vec<_> = arr let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_pointer_value).collect();
.into_iter()
.map(BasicValueEnum::into_pointer_value)
.collect();
ty.into_pointer_type().const_array(&arr) ty.into_pointer_type().const_array(&arr)
} else { } else {
unreachable!() unreachable!()
@ -714,16 +676,11 @@ impl InnerResolver {
arr_global.set_initializer(&arr); arr_global.set_initializer(&arr);
let val = arr_ty.const_named_struct(&[ let val = arr_ty.const_named_struct(&[
arr_global arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(),
.as_pointer_value()
.const_cast(ty.ptr_type(AddressSpace::Generic))
.into(),
size_t.const_int(len as u64, false).into(), size_t.const_int(len as u64, false).into(),
]); ]);
let global = ctx let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str);
.module
.add_global(arr_ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val); global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
@ -754,8 +711,7 @@ impl InnerResolver {
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
.add_global(ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
@ -763,15 +719,11 @@ impl InnerResolver {
} }
} }
let val: Result<Option<Vec<_>>, _> = elements let val: Result<Option<Vec<_>>, _> =
.iter() elements.iter().map(|elem| self.get_obj_value(py, elem, ctx, generator)).collect();
.map(|elem| self.get_obj_value(py, elem, ctx, generator))
.collect();
let val = val?.unwrap(); let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false); let val = ctx.ctx.const_struct(&val, false);
let global = ctx let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
.module
.add_global(ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val); global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
} else { } else {
@ -793,8 +745,7 @@ impl InnerResolver {
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
.add_global(ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
@ -802,10 +753,8 @@ impl InnerResolver {
} }
} }
// should be classes // should be classes
let definition = top_level_defs let definition =
.get(self.pyid_to_def.read().get(&ty_id).unwrap().0) top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read();
.unwrap()
.read();
if let TopLevelDef::Class { fields, .. } = &*definition { if let TopLevelDef::Class { fields, .. } = &*definition {
let values: Result<Option<Vec<_>>, _> = fields let values: Result<Option<Vec<_>>, _> = fields
.iter() .iter()
@ -816,9 +765,7 @@ impl InnerResolver {
let values = values?; let values = values?;
if let Some(values) = values { if let Some(values) = values {
let val = ty.const_named_struct(&values); let val = ty.const_named_struct(&values);
let global = ctx let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
.module
.add_global(ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val); global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
} else { } else {
@ -835,13 +782,9 @@ impl InnerResolver {
py: Python, py: Python,
obj: &PyAny, obj: &PyAny,
) -> PyResult<Result<SymbolValue, String>> { ) -> PyResult<Result<SymbolValue, String>> {
let ty_id: u64 = self let ty_id: u64 =
.helper self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?;
.id_fn Ok(if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?
.extract(py)?;
Ok(
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
let val: i32 = obj.extract()?; let val: i32 = obj.extract()?;
Ok(SymbolValue::I32(val)) Ok(SymbolValue::I32(val))
} else if ty_id == self.primitive_ids.int64 { } else if ty_id == self.primitive_ids.int64 {
@ -855,10 +798,8 @@ impl InnerResolver {
Ok(SymbolValue::Double(val)) Ok(SymbolValue::Double(val))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?; let elements: &PyTuple = obj.cast_as()?;
let elements: Result<Result<Vec<_>, String>, _> = elements let elements: Result<Result<Vec<_>, String>, _> =
.iter() elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect();
.map(|elem| self.get_default_param_obj_value(py, elem))
.collect();
let elements = match elements? { let elements = match elements? {
Ok(el) => el, Ok(el) => el,
Err(err) => return Ok(Err(err)), Err(err) => return Ok(Err(err)),
@ -866,8 +807,7 @@ impl InnerResolver {
Ok(SymbolValue::Tuple(elements)) Ok(SymbolValue::Tuple(elements))
} else { } else {
Err("only primitives values and tuple can be default parameter value".into()) Err("only primitives values and tuple can be default parameter value".into())
}, })
)
} }
} }
@ -882,12 +822,8 @@ impl SymbolResolver for Resolver {
for (key, val) in members.iter() { for (key, val) in members.iter() {
let key: &str = key.extract()?; let key: &str = key.extract()?;
if key == id.to_string() { if key == id.to_string() {
sym_value = Some( sym_value =
self.0 Some(self.0.get_default_param_obj_value(py, val).unwrap().unwrap());
.get_default_param_obj_value(py, val)
.unwrap()
.unwrap(),
);
break; break;
} }
} }
@ -992,7 +928,8 @@ impl SymbolResolver for Resolver {
id_to_def.get(&id).cloned().ok_or_else(|| "".to_string()) id_to_def.get(&id).cloned().ok_or_else(|| "".to_string())
} }
.or_else(|_| { .or_else(|_| {
let py_id = self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?; let py_id =
self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?;
let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or(format!( let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or(format!(
"`{}` is not registered in nac3, did you forgot to add @nac3?", "`{}` is not registered in nac3, did you forgot to add @nac3?",
id id
@ -1008,8 +945,9 @@ impl SymbolResolver for Resolver {
*id *id
} else { } else {
let id = Python::with_gil(|py| -> PyResult<i32> { let id = Python::with_gil(|py| -> PyResult<i32> {
self.0.helper.store_str.call1(py, (s, ))?.extract(py) self.0.helper.store_str.call1(py, (s,))?.extract(py)
}).unwrap(); })
.unwrap();
string_store.insert(s.into(), id); string_store.insert(s.into(), id);
id id
} }

View File

@ -1,5 +1,5 @@
use nac3core::codegen::CodeGenContext;
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}; use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering};
use nac3core::codegen::CodeGenContext;
pub trait TimeFns { pub trait TimeFns {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx>; fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx>;
@ -19,41 +19,23 @@ impl TimeFns for NowPinningTimeFns64 {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr =
now, ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr");
i32_type.ptr_type(AddressSpace::Generic),
"now_hiptr"
);
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep")
now_hiptr,
&[i32_type.const_int(2, false)],
"now_gep",
)
}; };
if let ( if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = (
BasicValueEnum::IntValue(now_hi),
BasicValueEnum::IntValue(now_lo)
) = (
ctx.builder.build_load(now_hiptr, "now_hi"), ctx.builder.build_load(now_hiptr, "now_hi"),
ctx.builder.build_load(now_loptr, "now_lo") ctx.builder.build_load(now_loptr, "now_lo"),
) { ) {
let zext_hi = ctx.builder.build_int_z_extend( let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
now_hi,
i64_type,
"now_zext_hi"
);
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi = ctx.builder.build_left_shift(
zext_hi, zext_hi,
i64_type.const_int(32, false), i64_type.const_int(32, false),
"now_shifted_zext_hi" "now_shifted_zext_hi",
);
let zext_lo = ctx.builder.build_int_z_extend(
now_lo,
i64_type,
"now_zext_lo"
); );
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
ctx.builder.build_or(shifted_hi, zext_lo, "now_or").into() ctx.builder.build_or(shifted_hi, zext_lo, "now_or").into()
} else { } else {
unreachable!(); unreachable!();
@ -69,8 +51,7 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t { if let BasicValueEnum::IntValue(time) = t {
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx.builder.build_int_truncate(
ctx.builder ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
.build_right_shift(time, i64_32, false, "now_lshr"),
i32_type, i32_type,
"now_trunc", "now_trunc",
); );
@ -86,11 +67,7 @@ impl TimeFns for NowPinningTimeFns64 {
); );
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep")
now_hiptr,
&[i32_type.const_int(2, false)],
"now_gep",
)
}; };
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
@ -108,60 +85,48 @@ impl TimeFns for NowPinningTimeFns64 {
} }
} }
fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { fn emit_delay_mu<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr =
now, ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr");
i32_type.ptr_type(AddressSpace::Generic),
"now_hiptr"
);
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_loptr")
now_hiptr,
&[i32_type.const_int(2, false)],
"now_loptr",
)
}; };
if let ( if let (
BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_hi),
BasicValueEnum::IntValue(now_lo), BasicValueEnum::IntValue(now_lo),
BasicValueEnum::IntValue(dt) BasicValueEnum::IntValue(dt),
) = ( ) = (
ctx.builder.build_load(now_hiptr, "now_hi"), ctx.builder.build_load(now_hiptr, "now_hi"),
ctx.builder.build_load(now_loptr, "now_lo"), ctx.builder.build_load(now_loptr, "now_lo"),
dt dt,
) { ) {
let zext_hi = ctx.builder.build_int_z_extend( let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
now_hi,
i64_type,
"now_zext_hi"
);
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi = ctx.builder.build_left_shift(
zext_hi, zext_hi,
i64_type.const_int(32, false), i64_type.const_int(32, false),
"now_shifted_zext_hi" "now_shifted_zext_hi",
);
let zext_lo = ctx.builder.build_int_z_extend(
now_lo,
i64_type,
"now_zext_lo"
); );
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now_or"); let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now_or");
let time = ctx.builder.build_int_add(now_val, dt, "now_add"); let time = ctx.builder.build_int_add(now_val, dt, "now_add");
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx.builder.build_int_truncate(
ctx.builder ctx.builder.build_right_shift(
.build_right_shift(
time, time,
i64_type.const_int(32, false), i64_type.const_int(32, false),
false, false,
"now_lshr" "now_lshr",
), ),
i32_type, i32_type,
"now_trunc", "now_trunc",
@ -200,9 +165,7 @@ impl TimeFns for NowPinningTimeFns {
if let BasicValueEnum::IntValue(now_raw) = now_raw { if let BasicValueEnum::IntValue(now_raw) = now_raw {
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl");
let now_hi = ctx let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr");
.builder
.build_right_shift(now_raw, i64_32, false, "now_lshr");
ctx.builder.build_or(now_lo, now_hi, "now_or").into() ctx.builder.build_or(now_lo, now_hi, "now_or").into()
} else { } else {
unreachable!(); unreachable!();
@ -215,8 +178,7 @@ impl TimeFns for NowPinningTimeFns {
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t { if let BasicValueEnum::IntValue(time) = t {
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx.builder.build_int_truncate(
ctx.builder ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
.build_right_shift(time, i64_32, false, "now_lshr"),
i32_type, i32_type,
"now_trunc", "now_trunc",
); );
@ -232,11 +194,7 @@ impl TimeFns for NowPinningTimeFns {
); );
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep")
now_hiptr,
&[i32_type.const_int(1, false)],
"now_gep",
)
}; };
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
@ -254,7 +212,11 @@ impl TimeFns for NowPinningTimeFns {
} }
} }
fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { fn emit_delay_mu<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
@ -263,18 +225,13 @@ impl TimeFns for NowPinningTimeFns {
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now");
if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) {
(now_raw, dt)
{
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl");
let now_hi = ctx let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr");
.builder
.build_right_shift(now_raw, i64_32, false, "now_lshr");
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or"); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or");
let time = ctx.builder.build_int_add(now_val, dt, "now_add"); let time = ctx.builder.build_int_add(now_val, dt, "now_add");
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx.builder.build_int_truncate(
ctx.builder ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
.build_right_shift(time, i64_32, false, "now_lshr"),
i32_type, i32_type,
"now_trunc", "now_trunc",
); );
@ -286,11 +243,7 @@ impl TimeFns for NowPinningTimeFns {
); );
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep")
now_hiptr,
&[i32_type.const_int(1, false)],
"now_gep",
)
}; };
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
@ -315,33 +268,36 @@ pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns { impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> { fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> {
let now_mu = ctx let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
.module ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
.get_function("now_mu") });
.unwrap_or_else(|| ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)); ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap()
ctx.builder
.build_call(now_mu, &[], "now_mu")
.try_as_basic_value()
.left()
.unwrap()
} }
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) {
let at_mu = ctx let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| {
.module ctx.module.add_function(
.get_function("at_mu") "at_mu",
.unwrap_or_else(|| ctx.module.add_function("at_mu", ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), None)); ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
ctx.builder None,
.build_call(at_mu, &[t.into()], "at_mu"); )
});
ctx.builder.build_call(at_mu, &[t.into()], "at_mu");
} }
fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { fn emit_delay_mu<'ctx, 'a>(
let delay_mu = ctx &self,
.module ctx: &mut CodeGenContext<'ctx, 'a>,
.get_function("delay_mu") dt: BasicValueEnum<'ctx>,
.unwrap_or_else(|| ctx.module.add_function("delay_mu", ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), None)); ) {
ctx.builder let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
.build_call(delay_mu, &[dt.into()], "delay_mu"); ctx.module.add_function(
"delay_mu",
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu");
} }
} }

View File

@ -33,7 +33,7 @@ pub enum Primitive {
None, None,
Range, Range,
Str, Str,
Exception Exception,
} }
#[derive(Debug)] #[derive(Debug)]
@ -162,10 +162,16 @@ impl ConcreteTypeStore {
// here we should not have type vars, but some partial instantiated // here we should not have type vars, but some partial instantiated
// class methods can still have uninstantiated type vars, so // class methods can still have uninstantiated type vars, so
// filter out all the methods, as this will not affect codegen // filter out all the methods, as this will not affect codegen
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) { if let TypeEnum::TFunc(..) = &*unifier.get_ty(ty.0) {
None None
} else { } else {
Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))) Some((
*name,
(
self.from_unifier_type(unifier, primitives, ty.0, cache),
ty.1,
),
))
} }
}) })
.collect(), .collect(),
@ -246,16 +252,13 @@ impl ConcreteTypeStore {
.map(|(name, cty)| { .map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
}) })
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>(),
.into(),
params: params params: params
.iter() .iter()
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>(),
.into(),
}, },
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc( ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
FunSignature {
args: args args: args
.iter() .iter()
.map(|arg| FuncArg { .map(|arg| FuncArg {
@ -267,13 +270,9 @@ impl ConcreteTypeStore {
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars vars: vars
.iter() .iter()
.map(|(id, cty)| { .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
(*id, self.to_unifier_type(unifier, primitives, *cty, cache))
})
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
} }),
.into(),
),
}; };
let result = unifier.add_ty(result); let result = unifier.add_ty(result);
if let Some(ty) = cache.get(&cty).unwrap() { if let Some(ty) = cache.get(&cty).unwrap() {

View File

@ -3,9 +3,9 @@ use std::{collections::HashMap, convert::TryInto, iter::once};
use crate::{ use crate::{
codegen::{ codegen::{
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
stmt::gen_raise,
get_llvm_type, get_llvm_type,
irrt::*, irrt::*,
stmt::gen_raise,
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
@ -13,12 +13,14 @@ use crate::{
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
}; };
use inkwell::{ use inkwell::{
AddressSpace,
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PointerValue} values::{BasicValueEnum, FunctionValue, IntValue, PointerValue},
AddressSpace,
}; };
use itertools::{chain, izip, zip, Itertools}; use itertools::{chain, izip, zip, Itertools};
use nac3parser::ast::{self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef}; use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
};
use super::CodeGenerator; use super::CodeGenerator;
@ -40,7 +42,14 @@ pub fn get_subst_key(
vars.extend(fun_vars.iter()); vars.extend(fun_vars.iter());
let sorted = vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); let sorted = vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted sorted
.map(|id| unifier.internal_stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None)) .map(|id| {
unifier.internal_stringify(
vars[id],
&mut |id| id.to_string(),
&mut |id| id.to_string(),
&mut None,
)
})
.join(", ") .join(", ")
} }
@ -77,14 +86,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
index index
} }
pub fn gen_symbol_val(&mut self, generator: &mut dyn CodeGenerator, val: &SymbolValue) -> BasicValueEnum<'ctx> { pub fn gen_symbol_val(
&mut self,
generator: &mut dyn CodeGenerator,
val: &SymbolValue,
) -> BasicValueEnum<'ctx> {
match val { match val {
SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(), SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(), SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(), SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
SymbolValue::Str(v) => { SymbolValue::Str(v) => {
let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); let str_ptr =
self.builder.build_global_string_ptr(v, "const").as_pointer_value().into();
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type();
ty.const_named_struct(&[str_ptr, size.into()]).into() ty.const_named_struct(&[str_ptr, size.into()]).into()
@ -125,7 +139,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) )
} }
pub fn gen_const(&mut self, generator: &mut dyn CodeGenerator, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { pub fn gen_const(
&mut self,
generator: &mut dyn CodeGenerator,
value: &Constant,
ty: Type,
) -> BasicValueEnum<'ctx> {
match value { match value {
Constant::Bool(v) => { Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.primitives.bool)); assert!(self.unifier.unioned(ty, self.primitives.bool));
@ -163,10 +182,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
if let Some(v) = self.const_strings.get(v) { if let Some(v) = self.const_strings.get(v) {
*v *v
} else { } else {
let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); let str_ptr =
self.builder.build_global_string_ptr(v, "const").as_pointer_value().into();
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str); let ty = self.get_llvm_type(generator, self.primitives.str);
let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); let val =
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
self.const_strings.insert(v.to_string(), val); self.const_strings.insert(v.to_string(), val);
val val
} }
@ -262,12 +283,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&self, &self,
fun: FunctionValue<'ctx>, fun: FunctionValue<'ctx>,
params: &[BasicValueEnum<'ctx>], params: &[BasicValueEnum<'ctx>],
call_name: &str call_name: &str,
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
if let Some(target) = self.unwind_target { if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name)); let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name));
let result = self.builder.build_invoke(fun, params, then_block, target, call_name).try_as_basic_value().left(); let result = self
.builder
.build_invoke(fun, params, then_block, target, call_name)
.try_as_basic_value()
.left();
self.builder.position_at_end(then_block); self.builder.position_at_end(then_block);
result result
} else { } else {
@ -279,7 +304,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
pub fn gen_string<G: CodeGenerator, S: Into<String>>( pub fn gen_string<G: CodeGenerator, S: Into<String>>(
&mut self, &mut self,
generator: &mut G, generator: &mut G,
s: S s: S,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str)
} }
@ -290,7 +315,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
name: &str, name: &str,
msg: BasicValueEnum<'ctx>, msg: BasicValueEnum<'ctx>,
params: [Option<IntValue<'ctx>>; 3], params: [Option<IntValue<'ctx>>; 3],
loc: Location loc: Location,
) { ) {
let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type();
let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into();
@ -302,13 +327,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let id = self.resolver.get_string_id(name); let id = self.resolver.get_string_id(name);
self.builder.build_store(id_ptr, int32.const_int(id as u64, false)); self.builder.build_store(id_ptr, int32.const_int(id as u64, false));
let ptr = self.builder.build_in_bounds_gep( let ptr = self.builder.build_in_bounds_gep(
zelf, &[zero, int32.const_int(5, false)], "exn.msg"); zelf,
&[zero, int32.const_int(5, false)],
"exn.msg",
);
self.builder.build_store(ptr, msg); self.builder.build_store(ptr, msg);
let i64_zero = self.ctx.i64_type().const_zero(); let i64_zero = self.ctx.i64_type().const_zero();
for (i, attr_ind) in [6, 7, 8].iter().enumerate() { for (i, attr_ind) in [6, 7, 8].iter().enumerate() {
let ptr = self.builder.build_in_bounds_gep( let ptr = self.builder.build_in_bounds_gep(
zelf, &[zero, int32.const_int(*attr_ind, false)], "exn.param"); zelf,
let val = params[i].map_or(i64_zero, |v| self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext")); &[zero, int32.const_int(*attr_ind, false)],
"exn.param",
);
let val = params[i].map_or(i64_zero, |v| {
self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext")
});
self.builder.build_store(ptr, val); self.builder.build_store(ptr, val);
} }
} }
@ -322,19 +355,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
err_name: &str, err_name: &str,
err_msg: &str, err_msg: &str,
params: [Option<IntValue<'ctx>>; 3], params: [Option<IntValue<'ctx>>; 3],
loc: Location loc: Location,
) { ) {
let i1 = self.ctx.bool_type(); let i1 = self.ctx.bool_type();
let i1_true = i1.const_all_ones(); let i1_true = i1.const_all_ones();
let expect_fun = self.module.get_function("llvm.expect.i1").unwrap_or_else(|| { let expect_fun = self.module.get_function("llvm.expect.i1").unwrap_or_else(|| {
self.module.add_function("llvm.expect", i1.fn_type(&[i1.into(), i1.into()], false), None) self.module.add_function(
"llvm.expect",
i1.fn_type(&[i1.into(), i1.into()], false),
None,
)
}); });
// we assume that the condition is most probably true, so the normal path is the most // we assume that the condition is most probably true, so the normal path is the most
// probable path // probable path
// even if this assumption is violated, it does not matter as exception unwinding is // even if this assumption is violated, it does not matter as exception unwinding is
// slow anyway... // slow anyway...
let cond = self.builder.build_call(expect_fun, &[cond.into(), i1_true.into()], "expect") let cond = self
.try_as_basic_value().left().unwrap().into_int_value(); .builder
.build_call(expect_fun, &[cond.into(), i1_true.into()], "expect")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value();
let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current_fun, "succ"); let then_block = self.ctx.append_basic_block(current_fun, "succ");
let exn_block = self.ctx.append_basic_block(current_fun, "fail"); let exn_block = self.ctx.append_basic_block(current_fun, "fail");
@ -400,7 +442,7 @@ pub fn gen_func_instance<'ctx, 'a>(
return Ok(sym.clone()); return Ok(sym.clone());
} }
let symbol = format!("{}.{}", name, instance_to_symbol.len()); let symbol = format!("{}.{}", name, instance_to_symbol.len());
instance_to_symbol.insert(key.clone(), symbol.clone()); instance_to_symbol.insert(key, symbol.clone());
let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id)); let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id));
let instance = instance_to_stmt.get(&key).unwrap(); let instance = instance_to_stmt.get(&key).unwrap();
@ -484,7 +526,10 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
} }
// default value handling // default value handling
for k in keys.into_iter() { for k in keys.into_iter() {
mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into()); mapping.insert(
k.name,
ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into(),
);
} }
// reorder the parameters // reorder the parameters
let mut real_params = let mut real_params =
@ -821,7 +866,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
// we should use memcpy for that instead of generating thousands of stores // we should use memcpy for that instead of generating thousands of stores
let elements = elts let elements = elts
.iter() .iter()
.map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) .map(|x| {
generator
.gen_expr(ctx, x)
.map(|v| v.unwrap().to_basic_value_enum(ctx, generator))
})
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let ty = if elements.is_empty() { let ty = if elements.is_empty() {
if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) { if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) {
@ -850,7 +899,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let element_val = elts let element_val = elts
.iter() .iter()
.map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) .map(|x| {
generator
.gen_expr(ctx, x)
.map(|v| v.unwrap().to_basic_value_enum(ctx, generator))
})
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec();
let tuple_ty = ctx.ctx.struct_type(&element_ty, false); let tuple_ty = ctx.ctx.struct_type(&element_ty, false);
@ -935,7 +988,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?,
ExprKind::UnaryOp { op, operand } => { ExprKind::UnaryOp { op, operand } => {
let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let ty = ctx.unifier.get_representative(operand.custom.unwrap());
let val = generator.gen_expr(ctx, operand)?.unwrap().to_basic_value_enum(ctx, generator); let val =
generator.gen_expr(ctx, operand)?.unwrap().to_basic_value_enum(ctx, generator);
if ty == ctx.primitives.bool { if ty == ctx.primitives.bool {
let val = val.into_int_value(); let val = val.into_int_value();
match op { match op {
@ -1074,8 +1128,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
phi.as_basic_value().into() phi.as_basic_value().into()
} }
ExprKind::Call { func, args, keywords } => { ExprKind::Call { func, args, keywords } => {
let mut params = let mut params = args
args.iter().map(|arg| Ok((None, generator.gen_expr(ctx, arg)?.unwrap())) as Result<_, String>) .iter()
.map(|arg| Ok((None, generator.gen_expr(ctx, arg)?.unwrap())) as Result<_, String>)
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let kw_iter = keywords.iter().map(|kw| { let kw_iter = keywords.iter().map(|kw| {
Ok(( Ok((
@ -1101,7 +1156,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
match &func.node { match &func.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
// TODO: handle primitive casts and function pointers // TODO: handle primitive casts and function pointers
let fun = ctx.resolver.get_identifier_def(*id).map_err(|e| format!("{} (at {})", e, func.location))?; let fun = ctx
.resolver
.get_identifier_def(*id)
.map_err(|e| format!("{} (at {})", e, func.location))?;
return Ok(generator return Ok(generator
.gen_call(ctx, None, (&signature, fun), params)? .gen_call(ctx, None, (&signature, fun), params)?
.map(|v| v.into())); .map(|v| v.into()));
@ -1187,24 +1245,47 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
); );
res_array_ret.into() res_array_ret.into()
} else { } else {
let len = ctx.build_gep_and_load(v, &[zero, int32.const_int(1, false)]) let len = ctx
.build_gep_and_load(v, &[zero, int32.const_int(1, false)])
.into_int_value(); .into_int_value();
let raw_index = generator let raw_index = generator
.gen_expr(ctx, slice)? .gen_expr(ctx, slice)?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator) .to_basic_value_enum(ctx, generator)
.into_int_value(); .into_int_value();
let raw_index = ctx.builder.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext"); let raw_index = ctx.builder.build_int_s_extend(
raw_index,
generator.get_size_type(ctx.ctx),
"sext",
);
// handle negative index // handle negative index
let is_negative = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, raw_index, let is_negative = ctx.builder.build_int_compare(
generator.get_size_type(ctx.ctx).const_zero(), "is_neg"); inkwell::IntPredicate::SLT,
raw_index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
);
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted");
let index = ctx.builder.build_select(is_negative, adjusted, raw_index, "index").into_int_value(); let index = ctx
.builder
.build_select(is_negative, adjusted, raw_index, "index")
.into_int_value();
// unsigned less than is enough, because negative index after adjustment is // unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp) // bigger than the length (for unsigned cmp)
let bound_check = ctx.builder.build_int_compare(inkwell::IntPredicate::ULT, index, len, "inbound"); let bound_check = ctx.builder.build_int_compare(
ctx.make_assert(generator, bound_check, "0:IndexError", "index {0} out of bounds 0:{1}", inkwell::IntPredicate::ULT,
[Some(raw_index), Some(len), None], expr.location); index,
len,
"inbound",
);
ctx.make_assert(
generator,
bound_check,
"0:IndexError",
"index {0} out of bounds 0:{1}",
[Some(raw_index), Some(len), None],
expr.location,
);
ctx.build_gep_and_load(arr_ptr, &[index]) ctx.build_gep_and_load(arr_ptr, &[index])
} }
} else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { } else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) {

View File

@ -118,8 +118,11 @@ pub trait CodeGenerator {
/// Generate code for a while expression. /// Generate code for a while expression.
/// Return true if the while loop must early return /// Return true if the while loop must early return
fn gen_while<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt<Option<Type>>) fn gen_while<'ctx, 'a>(
-> Result<(), String> &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
@ -128,8 +131,11 @@ pub trait CodeGenerator {
/// Generate code for a while expression. /// Generate code for a while expression.
/// Return true if the while loop must early return /// Return true if the while loop must early return
fn gen_for<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt<Option<Type>>) fn gen_for<'ctx, 'a>(
-> Result<(), String> &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
@ -138,16 +144,22 @@ pub trait CodeGenerator {
/// Generate code for an if expression. /// Generate code for an if expression.
/// Return true if the statement must early return /// Return true if the statement must early return
fn gen_if<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt<Option<Type>>) fn gen_if<'ctx, 'a>(
-> Result<(), String> &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_if(self, ctx, stmt) gen_if(self, ctx, stmt)
} }
fn gen_with<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt<Option<Type>>) fn gen_with<'ctx, 'a>(
-> Result<(), String> &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
@ -156,8 +168,11 @@ pub trait CodeGenerator {
/// Generate code for a statement /// Generate code for a statement
/// Return true if the statement must early return /// Return true if the statement must early return
fn gen_stmt<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt<Option<Type>>) fn gen_stmt<'ctx, 'a>(
-> Result<(), String> &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {

View File

@ -215,7 +215,8 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>(
}); });
let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator); let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator);
Ok(ctx.builder Ok(ctx
.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind") .build_call(func, &[i.into(), length.into()], "bounded_ind")
.try_as_basic_value() .try_as_basic_value()
.left() .left()

View File

@ -30,8 +30,8 @@ use std::thread;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
mod generator; mod generator;
pub mod stmt;
pub mod irrt; pub mod irrt;
pub mod stmt;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -274,12 +274,22 @@ fn get_llvm_type<'ctx>(
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read(); let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(obj_id.0).unwrap(); let definition = top_level_defs.get(obj_id.0).unwrap();
let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } = &*definition.read() let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } =
&*definition.read()
{ {
let struct_type = ctx.opaque_struct_type(&name.to_string()); let struct_type = ctx.opaque_struct_type(&name.to_string());
let fields = fields_list let fields = fields_list
.iter() .iter()
.map(|f| get_llvm_type(ctx, generator, unifier, top_level, type_cache, fields[&f.0].0)) .map(|f| {
get_llvm_type(
ctx,
generator,
unifier,
top_level,
type_cache,
fields[&f.0].0,
)
})
.collect_vec(); .collect_vec();
struct_type.set_body(&fields, false); struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::Generic).into() struct_type.ptr_type(AddressSpace::Generic).into()
@ -298,9 +308,12 @@ fn get_llvm_type<'ctx>(
} }
TList { ty } => { TList { ty } => {
// a struct with an integer and a pointer to an array // a struct with an integer and a pointer to an array
let element_type = get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty); let element_type =
let fields = get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty);
[element_type.ptr_type(AddressSpace::Generic).into(), generator.get_size_type(ctx).into()]; let fields = [
element_type.ptr_type(AddressSpace::Generic).into(),
generator.get_size_type(ctx).into(),
];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
} }
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
@ -331,14 +344,17 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
// this should be unification between variables and concrete types // this should be unification between variables and concrete types
// and should not cause any problem... // and should not cause any problem...
let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache); let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache);
unifier.unify(*a, b).or_else(|err| { unifier
.unify(*a, b)
.or_else(|err| {
if matches!(&*unifier.get_ty(*a), TypeEnum::TRigidVar { .. }) { if matches!(&*unifier.get_ty(*a), TypeEnum::TRigidVar { .. }) {
unifier.replace_rigid_var(*a, b); unifier.replace_rigid_var(*a, b);
Ok(()) Ok(())
} else { } else {
Err(err) Err(err)
} }
}).unwrap() })
.unwrap()
} }
// rebuild primitive store with unique representatives // rebuild primitive store with unique representatives
@ -367,10 +383,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
str_type.set_body(&fields, false); str_type.set_body(&fields, false);
str_type.into() str_type.into()
}), }),
( (primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into()),
primitives.range,
context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into(),
),
] ]
.iter() .iter()
.cloned() .cloned()
@ -380,17 +393,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
let int32 = context.i32_type().into(); let int32 = context.i32_type().into();
let int64 = context.i64_type().into(); let int64 = context.i64_type().into();
let str_ty = *type_cache.get(&primitives.str).unwrap(); let str_ty = *type_cache.get(&primitives.str).unwrap();
let fields = [ let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
int32,
str_ty,
int32,
int32,
str_ty,
str_ty,
int64,
int64,
int64
];
exception.set_body(&fields, false); exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::Generic).into() exception.ptr_type(AddressSpace::Generic).into()
}); });
@ -414,14 +417,29 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
let params = args let params = args
.iter() .iter()
.map(|arg| { .map(|arg| {
get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty).into() get_llvm_type(
context,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
arg.ty,
)
.into()
}) })
.collect_vec(); .collect_vec();
let fn_type = if unifier.unioned(ret, primitives.none) { let fn_type = if unifier.unioned(ret, primitives.none) {
context.void_type().fn_type(&params, false) context.void_type().fn_type(&params, false)
} else { } else {
get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret) get_llvm_type(
context,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
ret,
)
.fn_type(&params, false) .fn_type(&params, false)
}; };
@ -445,7 +463,14 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
for (n, arg) in args.iter().enumerate() { for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param(n as u32).unwrap(); let param = fn_val.get_nth_param(n as u32).unwrap();
let alloca = builder.build_alloca( let alloca = builder.build_alloca(
get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty), get_llvm_type(
context,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
arg.ty,
),
&arg.name.to_string(), &arg.name.to_string(),
); );
builder.build_store(alloca, param); builder.build_store(alloca, param);

View File

@ -7,7 +7,7 @@ use super::{
use crate::{ use crate::{
codegen::expr::gen_binop_expr, codegen::expr::gen_binop_expr,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{Type, TypeEnum, FunSignature} typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -16,7 +16,9 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
IntPredicate::EQ, IntPredicate::EQ,
}; };
use nac3parser::ast::{ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant}; use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
};
use std::convert::TryFrom; use std::convert::TryFrom;
pub fn gen_var<'ctx, 'a>( pub fn gen_var<'ctx, 'a>(
@ -40,12 +42,16 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
// very similar to gen_expr, but we don't do an extra load at the end // very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples // and we flatten nested tuples
Ok(match &pattern.node { Ok(match &pattern.node {
ExprKind::Name { id, .. } => ctx.var_assignment.get(id).map(|v| Ok(v.0) as Result<_, String>).unwrap_or_else(|| { ExprKind::Name { id, .. } => {
ctx.var_assignment.get(id).map(|v| Ok(v.0) as Result<_, String>).unwrap_or_else(
|| {
let ptr_ty = ctx.get_llvm_type(generator, pattern.custom.unwrap()); let ptr_ty = ctx.get_llvm_type(generator, pattern.custom.unwrap());
let ptr = generator.gen_var_alloc(ctx, ptr_ty)?; let ptr = generator.gen_var_alloc(ctx, ptr_ty)?;
ctx.var_assignment.insert(*id, (ptr, None, 0)); ctx.var_assignment.insert(*id, (ptr, None, 0));
Ok(ptr) Ok(ptr)
})?, },
)?
}
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = generator.gen_expr(ctx, value)?.unwrap().to_basic_value_enum(ctx, generator); let val = generator.gen_expr(ctx, value)?.unwrap().to_basic_value_enum(ctx, generator);
@ -94,7 +100,7 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
Ok(match &target.node { match &target.node {
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
if let BasicValueEnum::StructValue(v) = value.to_basic_value_enum(ctx, generator) { if let BasicValueEnum::StructValue(v) = value.to_basic_value_enum(ctx, generator) {
for (i, elt) in elts.iter().enumerate() { for (i, elt) in elts.iter().enumerate() {
@ -120,9 +126,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
let (start, end, step) = let (start, end, step) =
handle_slice_indices(lower, upper, step, ctx, generator, ls)?; handle_slice_indices(lower, upper, step, ctx, generator, ls)?;
let value = value.to_basic_value_enum(ctx, generator).into_pointer_value(); let value = value.to_basic_value_enum(ctx, generator).into_pointer_value();
let ty = if let TypeEnum::TList { ty } = let ty =
&*ctx.unifier.get_ty(target.custom.unwrap()) if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) {
{
ctx.get_llvm_type(generator, *ty) ctx.get_llvm_type(generator, *ty)
} else { } else {
unreachable!() unreachable!()
@ -153,7 +158,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
let val = value.to_basic_value_enum(ctx, generator); let val = value.to_basic_value_enum(ctx, generator);
ctx.builder.build_store(ptr, val); ctx.builder.build_store(ptr, val);
} }
}) };
Ok(())
} }
pub fn gen_for<'ctx, 'a, G: CodeGenerator>( pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
@ -420,10 +426,10 @@ pub fn get_builtins<'ctx, 'a, G: CodeGenerator>(
) -> FunctionValue<'ctx> { ) -> FunctionValue<'ctx> {
ctx.module.get_function(symbol).unwrap_or_else(|| { ctx.module.get_function(symbol).unwrap_or_else(|| {
let ty = match symbol { let ty = match symbol {
"__artiq_raise" => ctx.ctx.void_type().fn_type( "__artiq_raise" => ctx
&[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], .ctx
false, .void_type()
), .fn_type(&[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], false),
"__artiq_resume" => ctx.ctx.void_type().fn_type(&[], false), "__artiq_resume" => ctx.ctx.void_type().fn_type(&[], false),
"__artiq_end_catch" => ctx.ctx.void_type().fn_type(&[], false), "__artiq_end_catch" => ctx.ctx.void_type().fn_type(&[], false),
_ => unimplemented!(), _ => unimplemented!(),
@ -444,7 +450,7 @@ pub fn exn_constructor<'ctx, 'a>(
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId), _fun: (&FunSignature, DefinitionId),
mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let (zelf_ty, zelf) = obj.unwrap(); let (zelf_ty, zelf) = obj.unwrap();
let zelf = zelf.to_basic_value_enum(ctx, generator).into_pointer_value(); let zelf = zelf.to_basic_value_enum(ctx, generator).into_pointer_value();
@ -459,19 +465,16 @@ pub fn exn_constructor<'ctx, 'a>(
}; };
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read(); let def = defs[zelf_id].read();
let zelf_name = if let TopLevelDef::Class { name, .. } = &*def { let zelf_name =
*name if let TopLevelDef::Class { name, .. } = &*def { *name } else { unreachable!() };
} else {
unreachable!()
};
let exception_name = format!("0:{}", zelf_name); let exception_name = format!("0:{}", zelf_name);
unsafe { unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id");
let id = ctx.resolver.get_string_id(&exception_name); let id = ctx.resolver.get_string_id(&exception_name);
ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)); ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false));
let empty_string = ctx.gen_const(generator, &Constant::Str("".into()), ctx.primitives.str); let empty_string = ctx.gen_const(generator, &Constant::Str("".into()), ctx.primitives.str);
let ptr = ctx.builder.build_in_bounds_gep( let ptr =
zelf, &[zero, int32.const_int(5, false)], "exn.msg"); ctx.builder.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg");
let msg = if !args.is_empty() { let msg = if !args.is_empty() {
args.remove(0).1.to_basic_value_enum(ctx, generator) args.remove(0).1.to_basic_value_enum(ctx, generator)
} else { } else {
@ -485,19 +488,28 @@ pub fn exn_constructor<'ctx, 'a>(
ctx.ctx.i64_type().const_zero().into() ctx.ctx.i64_type().const_zero().into()
}; };
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, &[zero, int32.const_int(*i, false)], "exn.param"); zelf,
&[zero, int32.const_int(*i, false)],
"exn.param",
);
ctx.builder.build_store(ptr, value); ctx.builder.build_store(ptr, value);
} }
// set file, func to empty string // set file, func to empty string
for i in [1, 4].iter() { for i in [1, 4].iter() {
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, &[zero, int32.const_int(*i, false)], "exn.str"); zelf,
&[zero, int32.const_int(*i, false)],
"exn.str",
);
ctx.builder.build_store(ptr, empty_string); ctx.builder.build_store(ptr, empty_string);
} }
// set ints to zero // set ints to zero
for i in [2, 3].iter() { for i in [2, 3].iter() {
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, &[zero, int32.const_int(*i, false)], "exn.ints"); zelf,
&[zero, int32.const_int(*i, false)],
"exn.ints",
);
ctx.builder.build_store(ptr, zero); ctx.builder.build_store(ptr, zero);
} }
} }
@ -515,17 +527,33 @@ pub fn gen_raise<'ctx, 'a, G: CodeGenerator>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let exception = exception.into_pointer_value(); let exception = exception.into_pointer_value();
let file_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr"); let file_ptr = ctx.builder.build_in_bounds_gep(
exception,
&[zero, int32.const_int(1, false)],
"file_ptr",
);
let filename = ctx.gen_string(generator, loc.file.0); let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename); ctx.builder.build_store(file_ptr, filename);
let row_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr"); let row_ptr = ctx.builder.build_in_bounds_gep(
exception,
&[zero, int32.const_int(2, false)],
"row_ptr",
);
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)); ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false));
let col_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr"); let col_ptr = ctx.builder.build_in_bounds_gep(
exception,
&[zero, int32.const_int(3, false)],
"col_ptr",
);
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)); ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false));
let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap());
let name_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr"); let name_ptr = ctx.builder.build_in_bounds_gep(
exception,
&[zero, int32.const_int(4, false)],
"name_ptr",
);
ctx.builder.build_store(name_ptr, fun_name); ctx.builder.build_store(name_ptr, fun_name);
} }
@ -599,7 +627,11 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
for handler_node in handlers.iter() { for handler_node in handlers.iter() {
let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node;
// none or Exception // none or Exception
if type_.is_none() || ctx.unifier.unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) { if type_.is_none()
|| ctx
.unifier
.unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception)
{
clauses.push(None); clauses.push(None);
found_catch_all = true; found_catch_all = true;
break; break;
@ -928,7 +960,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>(
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {
if let Some(exc) = exc { if let Some(exc) = exc {
let exc = generator.gen_expr(ctx, exc)?.unwrap().to_basic_value_enum(ctx, generator); let exc =
generator.gen_expr(ctx, exc)?.unwrap().to_basic_value_enum(ctx, generator);
gen_raise(generator, ctx, Some(&exc), stmt.location); gen_raise(generator, ctx, Some(&exc), stmt.location);
} else { } else {
gen_raise(generator, ctx, None, stmt.location); gen_raise(generator, ctx, None, stmt.location);

View File

@ -34,7 +34,10 @@ impl Resolver {
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> { fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!() unimplemented!()
} }
@ -57,7 +60,11 @@ impl SymbolResolver for Resolver {
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> {
self.id_to_def.read().get(&id).cloned().ok_or_else(|| format!("cannot find symbol `{}`", id)) self.id_to_def
.read()
.get(&id)
.cloned()
.ok_or_else(|| format!("cannot find symbol `{}`", id))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -118,7 +125,7 @@ fn test_primitives() {
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
defined_identifiers: identifiers.clone(), defined_identifiers: identifiers.clone(),
in_handler: false in_handler: false,
}; };
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
@ -263,7 +270,7 @@ fn test_simple_call() {
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
defined_identifiers: identifiers.clone(), defined_identifiers: identifiers.clone(),
in_handler: false in_handler: false,
}; };
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty); inferencer.variable_mapping.insert("foo".into(), fun_ty);

View File

@ -1,7 +1,8 @@
use std::{collections::HashMap, fmt::Display};
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, fmt::Display};
use crate::typecheck::typedef::TypeEnum;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
@ -13,7 +14,6 @@ use crate::{
typedef::{Type, Unifier}, typedef::{Type, Unifier},
}, },
}; };
use crate::typecheck::typedef::TypeEnum;
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue};
use itertools::{chain, izip}; use itertools::{chain, izip};
use nac3parser::ast::{Expr, Location, StrRef}; use nac3parser::ast::{Expr, Location, StrRef};
@ -36,11 +36,13 @@ impl Display for SymbolValue {
SymbolValue::I64(i) => write!(f, "int64({})", i), SymbolValue::I64(i) => write!(f, "int64({})", i),
SymbolValue::Str(s) => write!(f, "\"{}\"", s), SymbolValue::Str(s) => write!(f, "\"{}\"", s),
SymbolValue::Double(d) => write!(f, "{}", d), SymbolValue::Double(d) => write!(f, "{}", d),
SymbolValue::Bool(b) => if *b { SymbolValue::Bool(b) => {
if *b {
write!(f, "True") write!(f, "True")
} else { } else {
write!(f, "False") write!(f, "False")
}, }
}
SymbolValue::Tuple(t) => { SymbolValue::Tuple(t) => {
write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join(", ")) write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join(", "))
} }
@ -203,7 +205,8 @@ pub fn parse_type_annotation<T>(
let fields = chain( let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))), fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))), methods.iter().map(|(k, v, _)| (*k, (*v, false))),
).collect(); )
.collect();
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj {
obj_id, obj_id,
fields, fields,
@ -214,7 +217,8 @@ pub fn parse_type_annotation<T>(
} }
} }
Err(e) => { Err(e) => {
let ty = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) let ty = resolver
.get_symbol_type(unifier, top_level_defs, primitives, *id)
.map_err(|_| format!("Unknown type annotation at {}: {}", loc, e))?; .map_err(|_| format!("Unknown type annotation at {}: {}", loc, e))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
@ -256,8 +260,7 @@ pub fn parse_type_annotation<T>(
vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?] vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?]
}; };
let obj_id = resolver let obj_id = resolver.get_identifier_def(*id)?;
.get_identifier_def(*id)?;
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
@ -287,11 +290,7 @@ pub fn parse_type_annotation<T>(
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, false)) (*attr, (ty, false))
})); }));
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
obj_id,
fields,
params: subst,
}))
} else { } else {
Err("Cannot use function name as type".into()) Err("Cannot use function name as type".into())
} }
@ -338,7 +337,7 @@ impl dyn SymbolResolver + Send + Sync {
} }
}, },
&mut |id| format!("var{}", id), &mut |id| format!("var{}", id),
&mut None &mut None,
) )
} }
} }

View File

@ -1,14 +1,13 @@
use super::*; use super::*;
use crate::{ use crate::{
codegen::{expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor}, codegen::{
expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor,
},
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
}; };
use inkwell::{FloatPredicate, IntPredicate}; use inkwell::{FloatPredicate, IntPredicate};
type BuiltinInfo = ( type BuiltinInfo = (Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>, &'static [&'static str]);
Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>,
&'static [&'static str]
);
pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let int32 = primitives.0.int32; let int32 = primitives.0.int32;
@ -17,7 +16,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let boolean = primitives.0.bool; let boolean = primitives.0.bool;
let range = primitives.0.range; let range = primitives.0.range;
let string = primitives.0.str; let string = primitives.0.str;
let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean], Some("N".into()), None); let num_ty = primitives.1.get_fresh_var_with_range(
&[int32, int64, float, boolean],
Some("N".into()),
None,
);
let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect();
let exception_fields = vec![ let exception_fields = vec![
@ -34,32 +37,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let div_by_zero = primitives.1.add_ty(TypeEnum::TObj { let div_by_zero = primitives.1.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: DefinitionId(10),
fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(),
params: Default::default() params: Default::default(),
}); });
let index_error = primitives.1.add_ty(TypeEnum::TObj { let index_error = primitives.1.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(11), obj_id: DefinitionId(11),
fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(),
params: Default::default() params: Default::default(),
}); });
let exn_cons_args = vec![ let exn_cons_args = vec![
FuncArg { name: "msg".into(), ty: string, FuncArg {
default_value: Some(SymbolValue::Str("".into()))}, name: "msg".into(),
FuncArg { name: "param0".into(), ty: int64, ty: string,
default_value: Some(SymbolValue::I64(0))}, default_value: Some(SymbolValue::Str("".into())),
FuncArg { name: "param1".into(), ty: int64, },
default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param2".into(), ty: int64, FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param2".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
]; ];
let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: exn_cons_args.clone(), args: exn_cons_args.clone(),
ret: div_by_zero, ret: div_by_zero,
vars: Default::default() vars: Default::default(),
})); }));
let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: exn_cons_args, args: exn_cons_args,
ret: index_error, ret: index_error,
vars: Default::default() vars: Default::default(),
})); }));
let top_level_def_list = vec![ let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
@ -83,8 +86,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
None, None,
None, None,
))), ))),
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(3, None, "bool".into(), None, None))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(4, None, "none".into(), None, None))), 3,
None,
"bool".into(),
None,
None,
))),
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
4,
None,
"none".into(),
None,
None,
))),
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
5, 5,
None, None,
@ -92,7 +107,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
None, None,
None, None,
))), ))),
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(6, None, "str".into(), None, None))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
6,
None,
"str".into(),
None,
None,
))),
Arc::new(RwLock::new(TopLevelDef::Class { Arc::new(RwLock::new(TopLevelDef::Class {
name: "Exception".into(), name: "Exception".into(),
object_id: DefinitionId(7), object_id: DefinitionId(7),
@ -134,7 +155,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
methods: vec![("__init__".into(), div_by_zero_signature, DefinitionId(8))], methods: vec![("__init__".into(), div_by_zero_signature, DefinitionId(8))],
ancestors: vec![ ancestors: vec![
TypeAnnotation::CustomClass { id: DefinitionId(10), params: Default::default() }, TypeAnnotation::CustomClass { id: DefinitionId(10), params: Default::default() },
TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() },
], ],
constructor: Some(div_by_zero_signature), constructor: Some(div_by_zero_signature),
resolver: None, resolver: None,
@ -148,7 +169,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
methods: vec![("__init__".into(), index_error_signature, DefinitionId(9))], methods: vec![("__init__".into(), index_error_signature, DefinitionId(9))],
ancestors: vec![ ancestors: vec![
TypeAnnotation::CustomClass { id: DefinitionId(11), params: Default::default() }, TypeAnnotation::CustomClass { id: DefinitionId(11), params: Default::default() },
TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() },
], ],
constructor: Some(index_error_signature), constructor: Some(index_error_signature),
resolver: None, resolver: None,
@ -233,7 +254,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let boolean = ctx.primitives.bool; let boolean = ctx.primitives.bool;
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
Ok(if ctx.unifier.unioned(arg_ty, boolean) Ok(
if ctx.unifier.unioned(arg_ty, boolean)
|| ctx.unifier.unioned(arg_ty, int32) || ctx.unifier.unioned(arg_ty, int32)
{ {
Some( Some(
@ -259,7 +281,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some(val) Some(val)
} else { } else {
unreachable!() unreachable!()
}) },
)
}, },
)))), )))),
loc: None, loc: None,
@ -284,7 +307,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let float = ctx.primitives.float; let float = ctx.primitives.float;
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
Ok(if ctx.unifier.unioned(arg_ty, boolean) Ok(
if ctx.unifier.unioned(arg_ty, boolean)
|| ctx.unifier.unioned(arg_ty, int32) || ctx.unifier.unioned(arg_ty, int32)
|| ctx.unifier.unioned(arg_ty, int64) || ctx.unifier.unioned(arg_ty, int64)
{ {
@ -298,7 +322,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some(arg) Some(arg)
} else { } else {
unreachable!() unreachable!()
}) },
)
}, },
)))), )))),
loc: None, loc: None,
@ -315,7 +340,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let round_intrinsic = let round_intrinsic =
ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
@ -338,7 +364,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -353,7 +380,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let round_intrinsic = let round_intrinsic =
ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
@ -376,7 +404,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -404,7 +433,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let mut start = None; let mut start = None;
let mut stop = None; let mut stop = None;
let mut step = None; let mut step = None;
@ -452,7 +482,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ctx.builder.build_store(c, step); ctx.builder.build_store(c, step);
} }
Ok(Some(ptr.into())) Ok(Some(ptr.into()))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -467,9 +498,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
Ok(Some(args[0].1.clone().to_basic_value_enum(ctx, generator))) Ok(Some(args[0].1.clone().to_basic_value_enum(ctx, generator)))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -495,28 +528,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Ok(if ctx.unifier.unioned(arg_ty, boolean) { Ok(if ctx.unifier.unioned(arg_ty, boolean) {
Some(arg) Some(arg)
} else if ctx.unifier.unioned(arg_ty, int32) { } else if ctx.unifier.unioned(arg_ty, int32) {
Some(ctx.builder.build_int_compare( Some(
ctx.builder
.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
ctx.ctx.i32_type().const_zero(), ctx.ctx.i32_type().const_zero(),
arg.into_int_value(), arg.into_int_value(),
"bool", "bool",
).into()) )
.into(),
)
} else if ctx.unifier.unioned(arg_ty, int64) { } else if ctx.unifier.unioned(arg_ty, int64) {
Some(ctx.builder.build_int_compare( Some(
ctx.builder
.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
ctx.ctx.i64_type().const_zero(), ctx.ctx.i64_type().const_zero(),
arg.into_int_value(), arg.into_int_value(),
"bool", "bool",
).into()) )
.into(),
)
} else if ctx.unifier.unioned(arg_ty, float) { } else if ctx.unifier.unioned(arg_ty, float) {
let val = ctx.builder. let val = ctx
build_float_compare( .builder
.build_float_compare(
// UEQ as bool(nan) is True // UEQ as bool(nan) is True
FloatPredicate::UEQ, FloatPredicate::UEQ,
arg.into_float_value(), arg.into_float_value(),
ctx.ctx.f64_type().const_zero(), ctx.ctx.f64_type().const_zero(),
"bool" "bool",
).into(); )
.into();
Some(val) Some(val)
} else { } else {
unreachable!() unreachable!()
@ -537,7 +580,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let floor_intrinsic = let floor_intrinsic =
ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
@ -560,7 +604,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -575,7 +620,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let floor_intrinsic = let floor_intrinsic =
ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
@ -598,7 +644,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -613,7 +660,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let ceil_intrinsic = let ceil_intrinsic =
ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
@ -636,7 +684,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -651,7 +700,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
let ceil_intrinsic = let ceil_intrinsic =
ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
@ -674,24 +724,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
) )
.into(), .into(),
)) ))
})))), },
)))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
let arg_ty = primitives.1.get_fresh_var_with_range(&[list, primitives.0.range], Some("I".into()), None); let arg_ty = primitives.1.get_fresh_var_with_range(
&[list, primitives.0.range],
Some("I".into()),
None,
);
TopLevelDef::Function { TopLevelDef::Function {
name: "len".into(), name: "len".into(),
simple_name: "len".into(), simple_name: "len".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }],
name: "ls".into(),
ty: arg_ty.0,
default_value: None
}],
ret: int32, ret: int32,
vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)]
.into_iter()
.collect(),
})), })),
var_id: vec![arg_ty.1], var_id: vec![arg_ty.1],
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
@ -709,7 +762,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
} else { } else {
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let len = ctx.build_gep_and_load(arg.into_pointer_value(), &[zero, int32.const_int(1, false)]).into_int_value(); let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, int32.const_int(1, false)],
)
.into_int_value();
if len.get_type().get_bit_width() != 32 { if len.get_type().get_bit_width() != 32 {
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
} else { } else {
@ -720,7 +778,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
} }
})) })),
]; ];
let ast_list: Vec<Option<ast::Stmt<()>>> = let ast_list: Vec<Option<ast::Stmt<()>>> =
(0..top_level_def_list.len()).map(|_| None).collect(); (0..top_level_def_list.len()).map(|_| None).collect();
@ -742,6 +800,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"ceil", "ceil",
"ceil64", "ceil64",
"len", "len",
] ],
) )
} }

View File

@ -1,9 +1,9 @@
use nac3parser::ast::fold::Fold; use nac3parser::ast::fold::Fold;
use crate::{ use crate::{
typecheck::type_inferencer::{FunctionData, Inferencer},
codegen::{expr::get_subst_key, stmt::exn_constructor}, codegen::{expr::get_subst_key, stmt::exn_constructor},
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::type_inferencer::{FunctionData, Inferencer},
}; };
use super::*; use super::*;
@ -15,10 +15,7 @@ pub struct ComposerConfig {
impl Default for ComposerConfig { impl Default for ComposerConfig {
fn default() -> Self { fn default() -> Self {
ComposerConfig { ComposerConfig { kernel_ann: None, kernel_invariant_ann: "Invariant" }
kernel_ann: None,
kernel_invariant_ann: "Invariant"
}
} }
} }
@ -52,7 +49,7 @@ impl TopLevelComposer {
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new( pub fn new(
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>, builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
core_config: ComposerConfig core_config: ComposerConfig,
) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) { ) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) {
let mut primitives = Self::make_primitives(); let mut primitives = Self::make_primitives();
let (mut definition_ast_list, builtin_name_list) = builtins::get_builtins(&mut primitives); let (mut definition_ast_list, builtin_name_list) = builtins::get_builtins(&mut primitives);
@ -89,7 +86,8 @@ impl TopLevelComposer {
assert!(name == *simple_name); assert!(name == *simple_name);
builtin_ty.insert(name, *signature); builtin_ty.insert(name, *signature);
builtin_id.insert(name, DefinitionId(id)); builtin_id.insert(name, DefinitionId(id));
} else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def { } else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def
{
assert!(id == object_id.0); assert!(id == object_id.0);
assert!(type_vars.is_empty()); assert!(type_vars.is_empty());
if let Some(constructor) = constructor { if let Some(constructor) = constructor {
@ -377,7 +375,7 @@ impl TopLevelComposer {
unreachable!("must be both class") unreachable!("must be both class")
} }
} else { } else {
return Ok(()) return Ok(());
} }
}; };
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
@ -484,14 +482,17 @@ impl TopLevelComposer {
let unifier = self.unifier.borrow_mut(); let unifier = self.unifier.borrow_mut();
let primitive_types = self.primitives_ty; let primitive_types = self.primitives_ty;
let mut get_direct_parents = |class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| { let mut get_direct_parents =
|class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| {
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = {
if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = if let TopLevelDef::Class {
class_def.deref_mut() ancestors, resolver, object_id, type_vars, ..
} = class_def.deref_mut()
{ {
if let Some(ast::Located { if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, .. node: ast::StmtKind::ClassDef { bases, .. },
..
}) = class_ast }) = class_ast
{ {
(object_id, bases, ancestors, resolver, type_vars) (object_id, bases, ancestors, resolver, type_vars)
@ -570,7 +571,7 @@ impl TopLevelComposer {
if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() { if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() {
(ancestors, *object_id) (ancestors, *object_id)
} else { } else {
return Ok(()) return Ok(());
} }
}; };
ancestors_store.insert( ancestors_store.insert(
@ -614,11 +615,17 @@ impl TopLevelComposer {
.insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id));
// special case classes that inherit from Exception // special case classes that inherit from Exception
if class_ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { if class_ancestors
.iter()
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
{
// if inherited from Exception, the body should be a pass // if inherited from Exception, the body should be a pass
if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node { if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node {
for stmt in body.iter() { for stmt in body.iter() {
if matches!(stmt.node, ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }) { if matches!(
stmt.node,
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
) {
return Err("Classes inherited from exception should have no custom fields/methods".into()); return Err("Classes inherited from exception should have no custom fields/methods".into());
} }
} }
@ -629,7 +636,9 @@ impl TopLevelComposer {
} }
// deal with ancestor of Exception object // deal with ancestor of Exception object
if let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *self.definition_ast_list[7].0.write() { if let TopLevelDef::Class { name, ancestors, object_id, .. } =
&mut *self.definition_ast_list[7].0.write()
{
assert_eq!(*name, "Exception".into()); assert_eq!(*name, "Exception".into());
ancestors.push(make_self_type_annotation(&[], *object_id)); ancestors.push(make_self_type_annotation(&[], *object_id));
} else { } else {
@ -658,7 +667,7 @@ impl TopLevelComposer {
unifier, unifier,
primitives, primitives,
&mut type_var_to_concrete_def, &mut type_var_to_concrete_def,
(&self.keyword_list, &self.core_config) (&self.keyword_list, &self.core_config),
) { ) {
errors.insert(e); errors.insert(e);
} }
@ -740,7 +749,7 @@ impl TopLevelComposer {
x x
} else { } else {
// if let TopLevelDef::Function { name, .. } = `` // if let TopLevelDef::Function { name, .. } = ``
return Ok(()) return Ok(());
}; };
if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } =
@ -760,7 +769,9 @@ impl TopLevelComposer {
// make sure no duplicate parameter // make sure no duplicate parameter
let mut defined_paramter_name: HashSet<_> = HashSet::new(); let mut defined_paramter_name: HashSet<_> = HashSet::new();
for x in args.args.iter() { for x in args.args.iter() {
if !defined_paramter_name.insert(x.node.arg) || keyword_list.contains(&x.node.arg) { if !defined_paramter_name.insert(x.node.arg)
|| keyword_list.contains(&x.node.arg)
{
return Err(format!( return Err(format!(
"top level function must have unique parameter names \ "top level function must have unique parameter names \
and names should not be the same as the keywords (at {})", and names should not be the same as the keywords (at {})",
@ -769,17 +780,21 @@ impl TopLevelComposer {
} }
} }
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> = args let arg_with_default: Vec<(
&ast::Located<ast::ArgData<()>>,
Option<&ast::Expr>,
)> = args
.args .args
.iter() .iter()
.rev() .rev()
.zip(args .zip(
.defaults args.defaults
.iter() .iter()
.rev() .rev()
.map(|x| -> Option<&ast::Expr> { Some(x) }) .map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)) .chain(std::iter::repeat(None)),
).collect_vec(); )
.collect_vec();
arg_with_default arg_with_default
.iter() .iter()
@ -839,16 +854,21 @@ impl TopLevelComposer {
default_value: match default { default_value: match default {
None => None, None => None,
Some(default) => Some({ Some(default) => Some({
let v = Self::parse_parameter_default_value(default, resolver)?; let v = Self::parse_parameter_default_value(
default, resolver,
)?;
Self::check_default_param_type( Self::check_default_param_type(
&v, &v,
&type_annotation, &type_annotation,
primitives_store, primitives_store,
unifier unifier,
).map_err(|err| format!("{} (at {})", err, x.location))?; )
.map_err(
|err| format!("{} (at {})", err, x.location),
)?;
v v
}) }),
} },
}) })
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
@ -910,18 +930,20 @@ impl TopLevelComposer {
.collect_vec() .collect_vec()
.as_slice() .as_slice()
); );
let function_ty = unifier.add_ty(TypeEnum::TFunc( let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } args: arg_types,
)); ret: return_ty,
unifier vars: function_var_map,
.unify(*dummy_ty, function_ty) }));
.map_err(|e| e.at(Some(function_ast.location)).to_display(unifier).to_string())?; unifier.unify(*dummy_ty, function_ty).map_err(|e| {
e.at(Some(function_ast.location)).to_display(unifier).to_string()
})?;
} else { } else {
unreachable!("must be both function"); unreachable!("must be both function");
} }
} else { } else {
// not top level function def, skip // not top level function def, skip
return Ok(()) return Ok(());
} }
Ok(()) Ok(())
}; };
@ -931,7 +953,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors.iter().join("\n----------\n")) return Err(errors.iter().join("\n----------\n"));
} }
Ok(()) Ok(())
} }
@ -1003,35 +1025,46 @@ impl TopLevelComposer {
let zelf: StrRef = "self".into(); let zelf: StrRef = "self".into();
for x in args.args.iter() { for x in args.args.iter() {
if !defined_paramter_name.insert(x.node.arg) if !defined_paramter_name.insert(x.node.arg)
|| (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) { || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf)
{
return Err(format!( return Err(format!(
"top level function must have unique parameter names \ "top level function must have unique parameter names \
and names should not be the same as the keywords (at {})", and names should not be the same as the keywords (at {})",
x.location x.location
)) ));
} }
} }
if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) { if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) {
return Err(format!("__init__ method must have a `self` parameter (at {})", b.location)); return Err(format!(
"__init__ method must have a `self` parameter (at {})",
b.location
));
} }
if !defined_paramter_name.contains(&zelf) { if !defined_paramter_name.contains(&zelf) {
return Err(format!("class method must have a `self` parameter (at {})", b.location)); return Err(format!(
"class method must have a `self` parameter (at {})",
b.location
));
} }
let mut result = Vec::new(); let mut result = Vec::new();
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> = args let arg_with_default: Vec<(
&ast::Located<ast::ArgData<()>>,
Option<&ast::Expr>,
)> = args
.args .args
.iter() .iter()
.rev() .rev()
.zip(args .zip(
.defaults args.defaults
.iter() .iter()
.rev() .rev()
.map(|x| -> Option<&ast::Expr> { Some(x) }) .map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)) .chain(std::iter::repeat(None)),
).collect_vec(); )
.collect_vec();
for (x, default) in arg_with_default.into_iter().rev() { for (x, default) in arg_with_default.into_iter().rev() {
let name = x.node.arg; let name = x.node.arg;
@ -1085,13 +1118,20 @@ impl TopLevelComposer {
return Err(format!("`self` parameter cannot take default value (at {})", x.location)); return Err(format!("`self` parameter cannot take default value (at {})", x.location));
} }
Some({ Some({
let v = Self::parse_parameter_default_value(default, class_resolver)?; let v = Self::parse_parameter_default_value(
Self::check_default_param_type(&v, &type_ann, primitives, unifier) default,
.map_err(|err| format!("{} (at {})", err, x.location))?; class_resolver,
)?;
Self::check_default_param_type(
&v, &type_ann, primitives, unifier,
)
.map_err(|err| {
format!("{} (at {})", err, x.location)
})?;
v v
}) })
} }
} },
}; };
// push the dummy type and the type annotation // push the dummy type and the type annotation
// into the list for later unification // into the list for later unification
@ -1162,14 +1202,17 @@ impl TopLevelComposer {
} else { } else {
unreachable!() unreachable!()
} }
let method_type = unifier.add_ty(TypeEnum::TFunc( let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: arg_types, ret: ret_type, vars: method_var_map } args: arg_types,
.into(), ret: ret_type,
)); vars: method_var_map,
}));
// unify now since function type is not in type annotation define // unify now since function type is not in type annotation define
// which should be fine since type within method_type will be subst later // which should be fine since type within method_type will be subst later
unifier.unify(method_dummy_ty, method_type).map_err(|e| e.to_display(unifier).to_string())?; unifier
.unify(method_dummy_ty, method_type)
.map_err(|e| e.to_display(unifier).to_string())?;
} }
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
if let ast::ExprKind::Name { id: attr, .. } = &target.node { if let ast::ExprKind::Name { id: attr, .. } = &target.node {
@ -1178,16 +1221,24 @@ impl TopLevelComposer {
// handle Kernel[T], KernelInvariant[T] // handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node { let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. } if matches!( ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node, &value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
) => (slice, false), ) =>
ast::ExprKind::Subscript { value, slice, .. } if matches!( {
(slice, false)
}
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node, &value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
) => (slice, true), ) =>
{
(slice, true)
}
_ if core_config.kernel_ann.is_none() => (annotation, true), _ if core_config.kernel_ann.is_none() => (annotation, true),
_ => continue // ignore fields annotated otherwise _ => continue, // ignore fields annotated otherwise
}; };
class_fields_def.push((*attr, dummy_field_type, mutable)); class_fields_def.push((*attr, dummy_field_type, mutable));
@ -1220,8 +1271,7 @@ impl TopLevelComposer {
} else { } else {
return Err(format!( return Err(format!(
"same class fields `{}` defined twice (at {})", "same class fields `{}` defined twice (at {})",
attr, attr, target.location
target.location
)); ));
} }
} else { } else {
@ -1233,10 +1283,12 @@ impl TopLevelComposer {
} }
ast::StmtKind::Pass { .. } => {} ast::StmtKind::Pass { .. } => {}
ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior
_ => return Err(format!( _ => {
return Err(format!(
"unsupported statement in class definition body (at {})", "unsupported statement in class definition body (at {})",
b.location b.location
)), ))
}
} }
} }
Ok(()) Ok(())
@ -1394,23 +1446,38 @@ impl TopLevelComposer {
primitives_ty, primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
)?; )?;
if ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { if ancestors
.iter()
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
{
// create constructor for these classes // create constructor for these classes
let string = primitives_ty.str; let string = primitives_ty.str;
let int64 = primitives_ty.int64; let int64 = primitives_ty.int64;
let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { name: "msg".into(), ty: string, FuncArg {
default_value: Some(SymbolValue::Str("".into()))}, name: "msg".into(),
FuncArg { name: "param0".into(), ty: int64, ty: string,
default_value: Some(SymbolValue::I64(0))}, default_value: Some(SymbolValue::Str("".into())),
FuncArg { name: "param1".into(), ty: int64, },
default_value: Some(SymbolValue::I64(0))}, FuncArg {
FuncArg { name: "param2".into(), ty: int64, name: "param0".into(),
default_value: Some(SymbolValue::I64(0))}, ty: int64,
default_value: Some(SymbolValue::I64(0)),
},
FuncArg {
name: "param1".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
},
FuncArg {
name: "param2".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
},
], ],
ret: self_type, ret: self_type,
vars: Default::default() vars: Default::default(),
})); }));
let cons_fun = TopLevelDef::Function { let cons_fun = TopLevelDef::Function {
name: format!("{}.{}", class_name, "__init__"), name: format!("{}.{}", class_name, "__init__"),
@ -1421,14 +1488,13 @@ impl TopLevelComposer {
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))),
loc: None loc: None,
}; };
constructors.push((i, signature, definition_extension.len())); constructors.push((i, signature, definition_extension.len()));
definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); definition_extension.push((Arc::new(RwLock::new(cons_fun)), None));
unifier unifier.unify(constructor.unwrap(), signature).map_err(|e| {
.unify(constructor.unwrap(), signature) e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string()
.map_err(|e| e.at(Some(ast.as_ref().unwrap().location)) })?;
.to_display(unifier).to_string())?;
return Ok(()); return Ok(());
} }
let mut init_id: Option<DefinitionId> = None; let mut init_id: Option<DefinitionId> = None;
@ -1439,7 +1505,9 @@ impl TopLevelComposer {
for (name, func_sig, id) in methods { for (name, func_sig, id) in methods {
if *name == init_str_id { if *name == init_str_id {
init_id = Some(*id); init_id = Some(*id);
if let TypeEnum::TFunc(FunSignature { args, vars, ..}) = unifier.get_ty(*func_sig).as_ref() { if let TypeEnum::TFunc(FunSignature { args, vars, .. }) =
unifier.get_ty(*func_sig).as_ref()
{
constructor_args.extend_from_slice(args); constructor_args.extend_from_slice(args);
type_vars.extend(vars); type_vars.extend(vars);
} else { } else {
@ -1449,17 +1517,18 @@ impl TopLevelComposer {
} }
(constructor_args, type_vars) (constructor_args, type_vars)
}; };
let contor_type = unifier.add_ty(TypeEnum::TFunc( let contor_type = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: contor_args, ret: self_type, vars: contor_type_vars } args: contor_args,
)); ret: self_type,
unifier vars: contor_type_vars,
.unify(constructor.unwrap(), contor_type) }));
.map_err(|e| e.at(Some(ast.as_ref().unwrap().location)).to_display(&unifier).to_string())?; unifier.unify(constructor.unwrap(), contor_type).map_err(|e| {
e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string()
})?;
// class field instantiation check // class field instantiation check
if let (Some(init_id), false) = (init_id, fields.is_empty()) { if let (Some(init_id), false) = (init_id, fields.is_empty()) {
let init_ast = let init_ast = definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
if *name != init_str_id { if *name != init_str_id {
unreachable!("must be init function here") unreachable!("must be init function here")
@ -1490,9 +1559,13 @@ impl TopLevelComposer {
} }
for (i, signature, id) in constructors.into_iter() { for (i, signature, id) in constructors.into_iter() {
if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() { if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write()
methods.push((init_str_id, signature, {
DefinitionId(self.definition_ast_list.len() + id))); methods.push((
init_str_id,
signature,
DefinitionId(self.definition_ast_list.len() + id),
));
} else { } else {
unreachable!() unreachable!()
} }
@ -1508,7 +1581,7 @@ impl TopLevelComposer {
let method_class = &mut self.method_class; let method_class = &mut self.method_class;
let mut analyze_2 = |id, def: &Arc<RwLock<TopLevelDef>>, ast: &Option<Stmt>| { let mut analyze_2 = |id, def: &Arc<RwLock<TopLevelDef>>, ast: &Option<Stmt>| {
if ast.is_none() { if ast.is_none() {
return Ok(()) return Ok(());
} }
let mut function_def = def.write(); let mut function_def = def.write();
if let TopLevelDef::Function { if let TopLevelDef::Function {
@ -1522,7 +1595,9 @@ impl TopLevelComposer {
.. ..
} = &mut *function_def } = &mut *function_def
{ {
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = unifier.get_ty(*signature).as_ref() { if let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
unifier.get_ty(*signature).as_ref()
{
// None if is not class method // None if is not class method
let uninst_self_type = { let uninst_self_type = {
if let Some(class_id) = method_class.get(&DefinitionId(id)) { if let Some(class_id) = method_class.get(&DefinitionId(id)) {
@ -1553,7 +1628,8 @@ impl TopLevelComposer {
.iter() .iter()
.map(|(_, ty)| { .map(|(_, ty)| {
unifier.get_instantiations(*ty).unwrap_or_else(|| { unifier.get_instantiations(*ty).unwrap_or_else(|| {
if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) { if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty)
{
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
no_ranges.push(rigid); no_ranges.push(rigid);
vec![rigid] vec![rigid]
@ -1588,14 +1664,13 @@ impl TopLevelComposer {
.collect_vec() .collect_vec()
}; };
let self_type = { let self_type = {
uninst_self_type uninst_self_type.clone().map(|(self_type, type_vars)| {
.clone()
.map(|(self_type, type_vars)| {
let subst_for_self = { let subst_for_self = {
let class_ty_var_ids = type_vars let class_ty_var_ids = type_vars
.iter() .iter()
.map(|x| { .map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x)
{
*id *id
} else { } else {
unreachable!("must be type var here"); unreachable!("must be type var here");
@ -1632,9 +1707,7 @@ impl TopLevelComposer {
defined_identifiers: identifiers.clone(), defined_identifiers: identifiers.clone(),
function_data: &mut FunctionData { function_data: &mut FunctionData {
resolver: resolver.as_ref().unwrap().clone(), resolver: resolver.as_ref().unwrap().clone(),
return_type: if unifier return_type: if unifier.unioned(inst_ret, primitives_ty.none) {
.unioned(inst_ret, primitives_ty.none)
{
None None
} else { } else {
Some(inst_ret) Some(inst_ret)
@ -1656,7 +1729,7 @@ impl TopLevelComposer {
primitives: primitives_ty, primitives: primitives_ty,
virtual_checks: &mut Vec::new(), virtual_checks: &mut Vec::new(),
calls: &mut calls, calls: &mut calls,
in_handler: false in_handler: false,
}; };
let fun_body = let fun_body =
@ -1696,7 +1769,10 @@ impl TopLevelComposer {
if let TypeEnum::TObj { obj_id, .. } = &*base { if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id *obj_id
} else { } else {
return Err(format!("Base type should be a class (at {})", loc)) return Err(format!(
"Base type should be a class (at {})",
loc
));
} }
}; };
let subtype_id = { let subtype_id = {
@ -1706,7 +1782,10 @@ impl TopLevelComposer {
} else { } else {
let base_repr = inferencer.unifier.stringify(*base); let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) return Err(format!(
"Expected a subtype of {}, but got {} (at {})",
base_repr, subtype_repr, loc
));
} }
}; };
let subtype_entry = defs[subtype_id.0].read(); let subtype_entry = defs[subtype_id.0].read();
@ -1716,7 +1795,10 @@ impl TopLevelComposer {
if m.is_none() { if m.is_none() {
let base_repr = inferencer.unifier.stringify(*base); let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) return Err(format!(
"Expected a subtype of {}, but got {} (at {})",
base_repr, subtype_repr, loc
));
} }
} else { } else {
unreachable!(); unreachable!();
@ -1748,12 +1830,7 @@ impl TopLevelComposer {
} }
instance_to_stmt.insert( instance_to_stmt.insert(
get_subst_key( get_subst_key(unifier, self_type, &subst, Some(insted_vars)),
unifier,
self_type,
&subst,
Some(insted_vars),
),
FunInstance { FunInstance {
body: Arc::new(fun_body), body: Arc::new(fun_body),
unifier_id: 0, unifier_id: 0,

View File

@ -1,32 +1,22 @@
use std::convert::TryInto; use std::convert::TryInto;
use nac3parser::ast::{Constant, Location};
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use nac3parser::ast::{Constant, Location};
use super::*; use super::*;
impl TopLevelDef { impl TopLevelDef {
pub fn to_string( pub fn to_string(&self, unifier: &mut Unifier) -> String {
&self,
unifier: &mut Unifier,
) -> String
{
match self { match self {
TopLevelDef::Class { TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => {
name, ancestors, fields, methods, type_vars, ..
} => {
let fields_str = fields let fields_str = fields
.iter() .iter()
.map(|(n, ty, _)| { .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
(n.to_string(), unifier.stringify(*ty))
})
.collect_vec(); .collect_vec();
let methods_str = methods let methods_str = methods
.iter() .iter()
.map(|(n, ty, id)| { .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id))
(n.to_string(), unifier.stringify(*ty), *id)
})
.collect_vec(); .collect_vec();
format!( format!(
"Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}",
@ -57,38 +47,38 @@ impl TopLevelComposer {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: DefinitionId(1),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: DefinitionId(2),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: DefinitionId(4),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: DefinitionId(6),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: DefinitionId(7),
@ -102,8 +92,10 @@ impl TopLevelComposer {
("__param0__".into(), (int64, true)), ("__param0__".into(), (int64, true)),
("__param1__".into(), (int64, true)), ("__param1__".into(), (int64, true)),
("__param2__".into(), (int64, true)), ("__param2__".into(), (int64, true)),
].into_iter().collect::<HashMap<_, _>>().into(), ]
params: HashMap::new().into(), .into_iter()
.collect::<HashMap<_, _>>(),
params: HashMap::new(),
}); });
let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception };
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
@ -117,7 +109,7 @@ impl TopLevelComposer {
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
name: StrRef, name: StrRef,
constructor: Option<Type>, constructor: Option<Type>,
loc: Option<Location> loc: Option<Location>,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Class { TopLevelDef::Class {
name, name,
@ -138,7 +130,7 @@ impl TopLevelComposer {
simple_name: StrRef, simple_name: StrRef,
ty: Type, ty: Type,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location> loc: Option<Location>,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Function { TopLevelDef::Function {
name, name,
@ -248,8 +240,11 @@ impl TopLevelComposer {
let this = this.as_ref(); let this = this.as_ref();
let other = unifier.get_ty(other); let other = unifier.get_ty(other);
let other = other.as_ref(); let other = other.as_ref();
if let (TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, ..}), if let (
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. })) = (this, other) { TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other)
{
// check args // check args
let args_ok = this_args let args_ok = this_args
.iter() .iter()
@ -362,11 +357,19 @@ impl TopLevelComposer {
Ok(result) Ok(result)
} }
pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result<SymbolValue, String> { pub fn parse_parameter_default_value(
default: &ast::Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
) -> Result<SymbolValue, String> {
parse_parameter_default_value(default, resolver) parse_parameter_default_value(default, resolver)
} }
pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> { pub fn check_default_param_type(
val: &SymbolValue,
ty: &TypeAnnotation,
primitive: &PrimitiveStore,
unifier: &mut Unifier,
) -> Result<(), String> {
let res = match val { let res = match val {
SymbolValue::Bool(..) => { SymbolValue::Bool(..) => {
if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) {
@ -430,33 +433,26 @@ impl TopLevelComposer {
} }
} }
pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result<SymbolValue, String> { pub fn parse_parameter_default_value(
default: &ast::Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
) -> Result<SymbolValue, String> {
fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, String> { fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, String> {
match val { match val {
Constant::Int(v) => { Constant::Int(v) => match v {
match v {
Some(v) => { Some(v) => {
if let Ok(v) = (*v).try_into() { if let Ok(v) = (*v).try_into() {
Ok(SymbolValue::I32(v)) Ok(SymbolValue::I32(v))
} else { } else {
Err(format!( Err(format!("integer value out of range at {}", loc))
"integer value out of range at {}",
loc
))
} }
}
None => Err(format!("integer value out of range at {}", loc)),
}, },
None => {
Err(format!(
"integer value out of range at {}",
loc
))
}
}
}
Constant::Float(v) => Ok(SymbolValue::Double(*v)), Constant::Float(v) => Ok(SymbolValue::Double(*v)),
Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), Constant::Bool(v) => Ok(SymbolValue::Bool(*v)),
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()? tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)), )),
_ => unimplemented!("this constant is not supported at {}", loc), _ => unimplemented!("this constant is not supported at {}", loc),
} }

View File

@ -34,7 +34,10 @@ impl ResolverInternal {
struct Resolver(Arc<ResolverInternal>); struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> { fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!() unimplemented!()
} }
@ -169,10 +172,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
{ {
let def = &*def.read(); let def = &*def.read();
if let TopLevelDef::Function { signature, name, .. } = def { if let TopLevelDef::Function { signature, name, .. } = def {
let ty_str = let ty_str = composer.unifier.internal_stringify(
composer *signature,
.unifier &mut |id| id.to_string(),
.internal_stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None); &mut |id| id.to_string(),
&mut None,
);
assert_eq!(ty_str, tys[i]); assert_eq!(ty_str, tys[i]);
assert_eq!(name, names[i]); assert_eq!(name, names[i]);
} }
@ -779,9 +784,12 @@ impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
type Error = String; type Error = String;
fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> { fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
Ok(if let Some(ty) = user { Ok(if let Some(ty) = user {
self.unifier.internal_stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| { self.unifier.internal_stringify(
format!("tvar{}", id.to_string()) ty,
}, &mut None) &mut |id| format!("class{}", id.to_string()),
&mut |id| format!("tvar{}", id.to_string()),
&mut None,
)
} else { } else {
"None".into() "None".into()
}) })

View File

@ -23,17 +23,27 @@ impl TypeAnnotation {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => { CustomClass { id, params } => {
let class_name = match unifier.top_level { let class_name = match unifier.top_level {
Some(ref top) => if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() { Some(ref top) => {
if let TopLevelDef::Class { name, .. } =
&*top.definitions.read()[id.0].read()
{
(*name).into() (*name).into()
} else { } else {
format!("def_{}", id.0) format!("def_{}", id.0)
} }
None => format!("def_{}", id.0) }
None => format!("def_{}", id.0),
}; };
format!("{{class: {}, params: {:?}}}", class_name, params.iter().map(|p| p.stringify(unifier)).collect_vec()) format!(
"{{class: {}, params: {:?}}}",
class_name,
params.iter().map(|p| p.stringify(unifier)).collect_vec()
)
} }
Virtual(ty) | List(ty) => ty.stringify(unifier), Virtual(ty) | List(ty) => ty.stringify(unifier),
Tuple(types) => format!("({:?})", types.iter().map(|p| p.stringify(unifier)).collect_vec()), Tuple(types) => {
format!("({:?})", types.iter().map(|p| p.stringify(unifier)).collect_vec())
}
} }
} }
} }
@ -47,7 +57,9 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap<DefinitionId, Vec<Type>>| { let name_handle = |id: &StrRef,
unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>>| {
if id == &"int32".into() { if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32)) Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() { } else if id == &"int64".into() {
@ -93,11 +105,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
unifier.unify(var, ty).unwrap(); unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty)) Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
Err(format!( Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location))
"`{}` is not a valid type annotation (at {})",
id,
expr.location
))
} }
} else { } else {
Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location))
@ -105,14 +113,15 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
let class_name_handle = let class_name_handle =
|id: &StrRef, slice: &ast::Expr<T>, unifier: &mut Unifier, mut locked: HashMap<DefinitionId, Vec<Type>>| { |id: &StrRef,
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] slice: &ast::Expr<T>,
.contains(id) unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()].contains(id)
{ {
return Err(format!("keywords cannot be class name (at {})", expr.location)); return Err(format!("keywords cannot be class name (at {})", expr.location));
} }
let obj_id = resolver let obj_id = resolver.get_identifier_def(*id)?;
.get_identifier_def(*id)?;
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read { if let Some(def_read) = def_read {
@ -157,9 +166,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars // make sure the result do not contain any type vars
let no_type_var = result let no_type_var =
.iter() result.iter().all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
.all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
if no_type_var { if no_type_var {
result result
} else { } else {
@ -297,8 +305,11 @@ pub fn get_type_from_type_annotation_kinds(
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
let temp = let temp = unifier.get_fresh_var_with_range(
unifier.get_fresh_var_with_range(range.as_slice(), *name, *loc); range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.0, p).is_ok()
} }
}; };
@ -338,7 +349,7 @@ pub fn get_type_from_type_annotation_kinds(
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
fields: tobj_fields, fields: tobj_fields,
params: subst.into(), params: subst,
})) }))
} }
} else { } else {

View File

@ -65,9 +65,11 @@ pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
} }
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F) pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>) where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
{ {
let (id, mut fields, params) = if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) { let (id, mut fields, params) =
if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) {
(*obj_id, fields.clone(), params.clone()) (*obj_id, fields.clone(), params.clone())
} else { } else {
unreachable!() unreachable!()
@ -75,11 +77,7 @@ pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
f(unifier, &mut fields); f(unifier, &mut fields);
unsafe { unsafe {
let unification_table = unifier.get_unification_table(); let unification_table = unifier.get_unification_table();
unification_table.set_value(ty, Rc::new(TypeEnum::TObj { unification_table.set_value(ty, Rc::new(TypeEnum::TObj { obj_id: id, fields, params }));
obj_id: id,
fields,
params,
}));
} }
} }
@ -106,8 +104,7 @@ pub fn impl_binop(
for op in ops { for op in ops {
fields.insert(binop_name(op).into(), { fields.insert(binop_name(op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
ret: ret_ty, ret: ret_ty,
vars: function_vars.clone(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
@ -115,16 +112,14 @@ pub fn impl_binop(
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
}], }],
} })),
)),
false, false,
) )
}); });
fields.insert(binop_assign_name(op).into(), { fields.insert(binop_assign_name(op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
ret: store.none, ret: store.none,
vars: function_vars.clone(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
@ -132,8 +127,7 @@ pub fn impl_binop(
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
}], }],
} })),
)),
false, false,
) )
}); });
@ -141,20 +135,17 @@ pub fn impl_binop(
}); });
} }
pub fn impl_unaryop( pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::Unaryop]) {
unifier: &mut Unifier,
ty: Type,
ret_ty: Type,
ops: &[ast::Unaryop],
) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(op).into(), unaryop_name(op).into(),
( (
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] } ret: ret_ty,
)), vars: HashMap::new(),
args: vec![],
})),
false, false,
), ),
); );
@ -174,8 +165,7 @@ pub fn impl_cmpop(
fields.insert( fields.insert(
comparison_name(op).unwrap().into(), comparison_name(op).unwrap().into(),
( (
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
ret: store.bool, ret: store.bool,
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
@ -183,8 +173,7 @@ pub fn impl_cmpop(
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
}], }],
} })),
)),
false, false,
), ),
); );

View File

@ -1,6 +1,6 @@
mod function_check; mod function_check;
pub mod magic_methods; pub mod magic_methods;
pub mod type_error;
pub mod type_inferencer; pub mod type_inferencer;
pub mod typedef; pub mod typedef;
pub mod type_error;
mod unification_table; mod unification_table;

View File

@ -1,9 +1,9 @@
use std::fmt::Display;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum; use crate::typecheck::typedef::TypeEnum;
use super::typedef::{Type, Unifier, RecordKey}; use super::typedef::{RecordKey, Type, Unifier};
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -53,16 +53,13 @@ impl TypeError {
} }
pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError { pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError {
DisplayTypeError { DisplayTypeError { err: self, unifier }
err: self,
unifier
}
} }
} }
pub struct DisplayTypeError<'a> { pub struct DisplayTypeError<'a> {
pub err: TypeError, pub err: TypeError,
pub unifier: &'a Unifier pub unifier: &'a Unifier,
} }
fn loc_to_str(loc: Option<Location>) -> String { fn loc_to_str(loc: Option<Location>) -> String {
@ -86,11 +83,7 @@ impl<'a> Display for DisplayTypeError<'a> {
UnknownArgName(name) => { UnknownArgName(name) => {
write!(f, "Unknown argument name: {}", name) write!(f, "Unknown argument name: {}", name)
} }
IncorrectArgType { IncorrectArgType { name, expected, got } => {
name,
expected,
got,
} => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!( write!(
@ -98,19 +91,26 @@ impl<'a> Display for DisplayTypeError<'a> {
"Incorrect argument type for {}. Expected {}, but got {}", "Incorrect argument type for {}. Expected {}, but got {}",
name, expected, got name, expected, got
) )
}, }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
let rhs = self.unifier.stringify_with_notes(types.1, &mut notes); let rhs = self.unifier.stringify_with_notes(types.1, &mut notes);
write!( write!(
f, f,
"Unable to unify field {}: Got types {}{} and {}{}", "Unable to unify field {}: Got types {}{} and {}{}",
field, lhs, loc_to_str(loc.0), rhs, loc_to_str(loc.1) field,
lhs,
loc_to_str(loc.0),
rhs,
loc_to_str(loc.1)
) )
} }
IncompatibleRange(t, ts) => { IncompatibleRange(t, ts) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
let ts = ts.iter().map(|t| self.unifier.stringify_with_notes(*t, &mut notes)).collect::<Vec<_>>(); let ts = ts
.iter()
.map(|t| self.unifier.stringify_with_notes(*t, &mut notes))
.collect::<Vec<_>>();
write!(f, "Expected any one of these types: {}, but got {}", ts.join(", "), t) write!(f, "Expected any one of these types: {}, but got {}", ts.join(", "), t)
} }
IncompatibleTypes(t1, t2) => { IncompatibleTypes(t1, t2) => {
@ -119,15 +119,21 @@ impl<'a> Display for DisplayTypeError<'a> {
match (&*type1, &*type2) { match (&*type1, &*type2) {
(TypeEnum::TCall(calls), _) => { (TypeEnum::TCall(calls), _) => {
let loc = self.unifier.calls[calls[0].0].loc; let loc = self.unifier.calls[calls[0].0].loc;
let result = write!(f, "{} is not callable", self.unifier.stringify_with_notes(*t2, &mut notes)); let result = write!(
f,
"{} is not callable",
self.unifier.stringify_with_notes(*t2, &mut notes)
);
if let Some(loc) = loc { if let Some(loc) = loc {
result?; result?;
write!(f, " (in {})", loc)?; write!(f, " (in {})", loc)?;
return Ok(()) return Ok(());
} }
result result
} }
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) if ty1.len() != ty2.len() => { (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
if ty1.len() != ty2.len() =>
{
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {} and {}", t1, t2) write!(f, "Tuple length mismatch: got {} and {}", t1, t2)
@ -152,7 +158,11 @@ impl<'a> Display for DisplayTypeError<'a> {
write!(f, "`{}::{}` field does not exist", t, name) write!(f, "`{}::{}` field does not exist", t, name)
} }
TupleIndexOutOfBounds { index, len } => { TupleIndexOutOfBounds { index, len } => {
write!(f, "Tuple index out of bounds. Got {} but tuple has only {} elements", index, len) write!(
f,
"Tuple index out of bounds. Got {} but tuple has only {} elements",
index, len
)
} }
RequiresTypeAnn => { RequiresTypeAnn => {
write!(f, "Unable to infer virtual object type: Type annotation required") write!(f, "Unable to infer virtual object type: Type annotation required")
@ -174,4 +184,3 @@ impl<'a> Display for DisplayTypeError<'a> {
Ok(()) Ok(())
} }
} }

View File

@ -3,7 +3,7 @@ use std::convert::{From, TryInto};
use std::iter::once; use std::iter::once;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier, RecordField}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId}; use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext};
use itertools::izip; use itertools::izip;
@ -125,7 +125,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
} }
ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment } => { ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment } => {
let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::<Result<Vec<_>, _>>()?; let body = body
.into_iter()
.map(|stmt| self.fold_stmt(stmt))
.collect::<Result<Vec<_>, _>>()?;
let outer_in_handler = self.in_handler; let outer_in_handler = self.in_handler;
let mut exception_handlers = Vec::with_capacity(handlers.len()); let mut exception_handlers = Vec::with_capacity(handlers.len());
self.in_handler = true; self.in_handler = true;
@ -133,23 +136,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
let top_level_defs = self.top_level.definitions.read(); let top_level_defs = self.top_level.definitions.read();
let mut naive_folder = NaiveFolder(); let mut naive_folder = NaiveFolder();
for handler in handlers.into_iter() { for handler in handlers.into_iter() {
let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = handler.node; let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } =
handler.node;
let type_ = if let Some(type_) = type_ { let type_ = if let Some(type_) = type_ {
let typ = self.function_data.resolver.parse_type_annotation( let typ = self.function_data.resolver.parse_type_annotation(
top_level_defs.as_slice(), top_level_defs.as_slice(),
self.unifier, self.unifier,
self.primitives, self.primitives,
&type_ &type_,
)?; )?;
self.virtual_checks.push((typ, self.primitives.exception, handler.location)); self.virtual_checks.push((
typ,
self.primitives.exception,
handler.location,
));
if let Some(name) = name { if let Some(name) = name {
if !self.defined_identifiers.contains(&name) { if !self.defined_identifiers.contains(&name) {
self.defined_identifiers.insert(name); self.defined_identifiers.insert(name);
} }
if let Some(old_typ) = self.variable_mapping.insert(name, typ) { if let Some(old_typ) = self.variable_mapping.insert(name, typ) {
let loc = handler.location; let loc = handler.location;
self.unifier.unify(old_typ, typ).map_err(|e| e.at(Some(loc)) self.unifier.unify(old_typ, typ).map_err(|e| {
.to_display(self.unifier).to_string())?; e.at(Some(loc)).to_display(self.unifier).to_string()
})?;
} }
} }
let mut type_ = naive_folder.fold_expr(*type_)?; let mut type_ = naive_folder.fold_expr(*type_)?;
@ -158,22 +167,32 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} else { } else {
None None
}; };
let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::<Result<Vec<_>, _>>()?; let body = body
.into_iter()
.map(|stmt| self.fold_stmt(stmt))
.collect::<Result<Vec<_>, _>>()?;
exception_handlers.push(Located { exception_handlers.push(Located {
location: handler.location, location: handler.location,
node: ast::ExcepthandlerKind::ExceptHandler { type_, name, body }, node: ast::ExcepthandlerKind::ExceptHandler { type_, name, body },
custom: None custom: None,
}); });
} }
} }
self.in_handler = outer_in_handler; self.in_handler = outer_in_handler;
let handlers = exception_handlers; let handlers = exception_handlers;
let orelse = orelse.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::<Result<Vec<_>, _>>()?; let orelse = orelse.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::<Result<
let finalbody = finalbody .into_iter().map(|stmt| self.fold_stmt(stmt)).collect::<Result<Vec<_>, _>>()?; Vec<_>,
_,
>>(
)?;
let finalbody = finalbody
.into_iter()
.map(|stmt| self.fold_stmt(stmt))
.collect::<Result<Vec<_>, _>>()?;
Located { Located {
location: node.location, location: node.location,
node: ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment }, node: ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment },
custom: None custom: None,
} }
} }
ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => { ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => {
@ -186,14 +205,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
self.unify(list, iter.custom.unwrap(), &iter.location)?; self.unify(list, iter.custom.unwrap(), &iter.location)?;
} }
let body = body let body =
.into_iter() body.into_iter().map(|b| self.fold_stmt(b)).collect::<Result<Vec<_>, _>>()?;
.map(|b| self.fold_stmt(b)) let orelse =
.collect::<Result<Vec<_>, _>>()?; orelse.into_iter().map(|o| self.fold_stmt(o)).collect::<Result<Vec<_>, _>>()?;
let orelse = orelse
.into_iter()
.map(|o| self.fold_stmt(o))
.collect::<Result<Vec<_>, _>>()?;
Located { Located {
location: node.location, location: node.location,
node: ast::StmtKind::For { node: ast::StmtKind::For {
@ -204,7 +219,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
config_comment, config_comment,
type_comment, type_comment,
}, },
custom: None custom: None,
} }
} }
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
@ -252,7 +267,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}) })
.collect(); .collect();
let loc = node.location; let loc = node.location;
let targets = targets.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?; let targets = targets
.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?;
return Ok(Located { return Ok(Located {
location: node.location, location: node.location,
node: ast::StmtKind::Assign { node: ast::StmtKind::Assign {
@ -283,8 +299,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
_ => fold::fold_stmt(self, node)?, _ => fold::fold_stmt(self, node)?,
}; };
match &stmt.node { match &stmt.node {
ast::StmtKind::For { .. } => {}, ast::StmtKind::For { .. } => {}
ast::StmtKind::Try { .. } => {}, ast::StmtKind::Try { .. } => {}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
} }
@ -302,9 +318,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
return report_error("raise ... from cause is not supported", cause.location); return report_error("raise ... from cause is not supported", cause.location);
} }
if let Some(exc) = exc { if let Some(exc) = exc {
self.virtual_checks.push((exc.custom.unwrap(), self.primitives.exception, exc.location)); self.virtual_checks.push((
exc.custom.unwrap(),
self.primitives.exception,
exc.location,
));
} else if !self.in_handler { } else if !self.in_handler {
return report_error("cannot reraise outside exception handlers", stmt.location); return report_error(
"cannot reraise outside exception handlers",
stmt.location,
);
} }
} }
ast::StmtKind::With { items, .. } => { ast::StmtKind::With { items, .. } => {
@ -419,8 +442,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
_ => fold::fold_expr(self, node)?, _ => fold::fold_expr(self, node)?,
}; };
let custom = match &expr.node { let custom = match &expr.node {
ast::ExprKind::Constant { value, .. } => ast::ExprKind::Constant { value, .. } => {
Some(self.infer_constant(value, &expr.location)?), Some(self.infer_constant(value, &expr.location)?)
}
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) { if !self.defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type( match self.function_data.resolver.get_symbol_type(
@ -481,7 +505,9 @@ impl<'a> Inferencer<'a> {
} }
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> {
self.unifier.unify(a, b).map_err(|e| e.at(Some(*location)).to_display(self.unifier).to_string()) self.unifier
.unify(a, b)
.map_err(|e| e.at(Some(*location)).to_display(self.unifier).to_string())
} }
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> { fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> {
@ -533,9 +559,9 @@ impl<'a> Inferencer<'a> {
.map(|v| v.name) .map(|v| v.name)
.rev() .rev()
.collect(); .collect();
self.unifier self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
.unify_call(&call, ty, sign, &required) e.at(Some(location)).to_display(self.unifier).to_string()
.map_err(|e| e.at(Some(location)).to_display(self.unifier).to_string())?; })?;
return Ok(sign.ret); return Ok(sign.ret);
} }
} }
@ -585,8 +611,11 @@ impl<'a> Inferencer<'a> {
defined_identifiers.insert(*name); defined_identifiers.insert(*name);
} }
} }
let fn_args: Vec<_> = let fn_args: Vec<_> = args
args.args.iter().map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)).collect(); .args
.iter()
.map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0))
.collect();
let mut variable_mapping = self.variable_mapping.clone(); let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned()); variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_dummy_var().0; let ret = self.unifier.get_dummy_var().0;
@ -649,7 +678,7 @@ impl<'a> Inferencer<'a> {
calls: self.calls, calls: self.calls,
defined_identifiers, defined_identifiers,
// listcomp expr should not be considered as inside an exception handler... // listcomp expr should not be considered as inside an exception handler...
in_handler: false in_handler: false,
}; };
let generator = generators.pop().unwrap(); let generator = generators.pop().unwrap();
if generator.is_async { if generator.is_async {
@ -784,7 +813,7 @@ impl<'a> Inferencer<'a> {
.collect(), .collect(),
fun: RefCell::new(None), fun: RefCell::new(None),
ret: sign.ret, ret: sign.ret,
loc: Some(location) loc: Some(location),
}; };
let required: Vec<_> = sign let required: Vec<_> = sign
.args .args
@ -813,7 +842,7 @@ impl<'a> Inferencer<'a> {
.collect(), .collect(),
fun: RefCell::new(None), fun: RefCell::new(None),
ret, ret,
loc: Some(location) loc: Some(location),
}); });
self.calls.insert(location.into(), call); self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
@ -853,8 +882,8 @@ impl<'a> Inferencer<'a> {
} else { } else {
report_error("Integer out of bound", *loc) report_error("Integer out of bound", *loc)
} }
}, }
None => report_error("Integer out of bound", *loc) None => report_error("Integer out of bound", *loc),
} }
} }
ast::Constant::Float(_) => Ok(self.primitives.float), ast::Constant::Float(_) => Ok(self.primitives.float),
@ -900,8 +929,11 @@ impl<'a> Inferencer<'a> {
} }
} else { } else {
let attr_ty = self.unifier.get_dummy_var().0; let attr_ty = self.unifier.get_dummy_var().0;
let fields = once((attr.into(), RecordField::new( let fields = once((
attr_ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); attr.into(),
RecordField::new(attr_ty, ctx == &ExprContext::Store, Some(value.location)),
))
.collect();
let record = self.unifier.add_record(fields); let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?; self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty) Ok(attr_ty)
@ -986,8 +1018,11 @@ impl<'a> Inferencer<'a> {
None => None, None => None,
}; };
let ind = ind.ok_or_else(|| "Index must be int32".to_string())?; let ind = ind.ok_or_else(|| "Index must be int32".to_string())?;
let map = once((ind.into(), RecordField::new( let map = once((
ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); ind.into(),
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
))
.collect();
let seq = self.unifier.add_record(map); let seq = self.unifier.add_record(map);
self.constrain(value.custom.unwrap(), seq, &value.location)?; self.constrain(value.custom.unwrap(), seq, &value.location)?;
Ok(ty) Ok(ty)

View File

@ -1,4 +1,4 @@
use super::super::{typedef::*, magic_methods::with_fields}; use super::super::{magic_methods::with_fields, typedef::*};
use super::*; use super::*;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
@ -18,7 +18,10 @@ struct Resolver {
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> { fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!() unimplemented!()
} }
@ -66,54 +69,51 @@ impl TestEnvironment {
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc( let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32, ret: int32,
vars: HashMap::new(), vars: HashMap::new(),
} }));
.into(),
));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: DefinitionId(1),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: DefinitionId(2),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: DefinitionId(4),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: DefinitionId(6),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: DefinitionId(7),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception };
set_primitives_magic_methods(&primitives, &mut unifier); set_primitives_magic_methods(&primitives, &mut unifier);
@ -167,58 +167,56 @@ impl TestEnvironment {
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new(); let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc( let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32, ret: int32,
vars: HashMap::new(), vars: HashMap::new(),
} }));
.into(),
));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: DefinitionId(1),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: DefinitionId(2),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: DefinitionId(4),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: DefinitionId(6),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: DefinitionId(7),
fields: HashMap::new().into(), fields: HashMap::new(),
params: HashMap::new().into(), params: HashMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
["int32", "int64", "float", "bool", "none", "range", "str", "Exception"].iter().enumerate() .iter()
.enumerate()
{ {
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
@ -230,7 +228,7 @@ impl TestEnvironment {
ancestors: Default::default(), ancestors: Default::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None loc: None,
}) })
.into(), .into(),
); );
@ -243,8 +241,8 @@ impl TestEnvironment {
let foo_ty = unifier.add_ty(TypeEnum::TObj { let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1), obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>().into(), fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(), params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
@ -263,26 +261,24 @@ impl TestEnvironment {
identifier_mapping.insert( identifier_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(), vars: [(id, v0)].iter().cloned().collect(),
} })),
.into(),
)),
); );
let fun = unifier.add_ty(TypeEnum::TFunc( let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(), args: vec![],
)); ret: int32,
vars: Default::default(),
}));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2), obj_id: DefinitionId(defs + 2),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))] fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
.iter() .iter()
.cloned() .cloned()
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>(),
.into(),
params: Default::default(), params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
@ -295,15 +291,17 @@ impl TestEnvironment {
ancestors: Default::default(), ancestors: Default::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None loc: None,
}) })
.into(), .into(),
); );
identifier_mapping.insert( identifier_mapping.insert(
"Bar".into(), "Bar".into(),
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(), args: vec![],
)), ret: bar,
vars: Default::default(),
})),
); );
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
@ -311,8 +309,7 @@ impl TestEnvironment {
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))] fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
.iter() .iter()
.cloned() .cloned()
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>(),
.into(),
params: Default::default(), params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
@ -325,15 +322,17 @@ impl TestEnvironment {
ancestors: Default::default(), ancestors: Default::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None loc: None,
}) })
.into(), .into(),
); );
identifier_mapping.insert( identifier_mapping.insert(
"Bar2".into(), "Bar2".into(),
unifier.add_ty(TypeEnum::TFunc( unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(), args: vec![],
)), ret: bar2,
vars: Default::default(),
})),
); );
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
@ -400,7 +399,7 @@ impl TestEnvironment {
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls, calls: &mut self.calls,
defined_identifiers: Default::default(), defined_identifiers: Default::default(),
in_handler: false in_handler: false,
} }
} }
} }
@ -493,7 +492,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
println!("{}: {}", k, name); println!("{}: {}", k, name);
} }
@ -503,7 +502,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
} }
@ -513,13 +512,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
*a, *a,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
let b = inferencer.unifier.internal_stringify( let b = inferencer.unifier.internal_stringify(
*b, *b,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
assert_eq!(&a, x); assert_eq!(&a, x);
@ -639,7 +638,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
println!("{}: {}", k, name); println!("{}: {}", k, name);
} }
@ -649,7 +648,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
&mut None &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
} }

View File

@ -6,10 +6,10 @@ use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{StrRef, Location}; use nac3parser::ast::{Location, StrRef};
use super::unification_table::{UnificationKey, UnificationTable};
use super::type_error::{TypeError, TypeErrorKind}; use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
@ -51,14 +51,14 @@ pub struct FunSignature {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecordKey { pub enum RecordKey {
Str(StrRef), Str(StrRef),
Int(i32) Int(i32),
} }
impl From<&RecordKey> for StrRef { impl From<&RecordKey> for StrRef {
fn from(r: &RecordKey) -> Self { fn from(r: &RecordKey) -> Self {
match r { match r {
RecordKey::Str(s) => *s, RecordKey::Str(s) => *s,
RecordKey::Int(i) => StrRef::from(i.to_string()) RecordKey::Int(i) => StrRef::from(i.to_string()),
} }
} }
} }
@ -85,7 +85,7 @@ impl Display for RecordKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
RecordKey::Str(s) => write!(f, "{}", s), RecordKey::Str(s) => write!(f, "{}", s),
RecordKey::Int(i) => write!(f, "{}", i) RecordKey::Int(i) => write!(f, "{}", i),
} }
} }
} }
@ -94,7 +94,7 @@ impl Display for RecordKey {
pub struct RecordField { pub struct RecordField {
ty: Type, ty: Type,
mutable: bool, mutable: bool,
loc: Option<Location> loc: Option<Location>,
} }
impl RecordField { impl RecordField {
@ -108,7 +108,7 @@ pub enum TypeEnum {
TRigidVar { TRigidVar {
id: u32, id: u32,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location> loc: Option<Location>,
}, },
TVar { TVar {
id: u32, id: u32,
@ -117,7 +117,7 @@ pub enum TypeEnum {
// empty indicates no restriction // empty indicates no restriction
range: Vec<Type>, range: Vec<Type>,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location> loc: Option<Location>,
}, },
TTuple { TTuple {
ty: Vec<Type>, ty: Vec<Type>,
@ -264,7 +264,11 @@ impl Unifier {
self.unification_table.probe_value_immutable(a).clone() self.unification_table.probe_value_immutable(a).clone()
} }
pub fn get_fresh_rigid_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> (Type, u32) { pub fn get_fresh_rigid_var(
&mut self,
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
(self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id) (self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id)
@ -279,11 +283,16 @@ impl Unifier {
} }
/// Get a fresh type variable. /// Get a fresh type variable.
pub fn get_fresh_var_with_range(&mut self, range: &[Type], name: Option<StrRef>, loc: Option<Location>) -> (Type, u32) { pub fn get_fresh_var_with_range(
&mut self,
range: &[Type],
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc}), id) (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id)
} }
/// Unification would not unify rigid variables with other types, but we want to do this for /// Unification would not unify rigid variables with other types, but we want to do this for
@ -344,8 +353,7 @@ impl Unifier {
.map(|params| { .map(|params| {
self.subst( self.subst(
ty, ty,
&zip(keys.iter().cloned(), params.iter().cloned()) &zip(keys.iter().cloned(), params.iter().cloned()).collect(),
.collect(),
) )
.unwrap_or(ty) .unwrap_or(ty)
}) })
@ -395,23 +403,19 @@ impl Unifier {
// we check to make sure that all required arguments (those without default // we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice. // arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec(); let mut required = required.to_vec();
let mut all_names: Vec<_> = let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
for (i, t) in posargs.iter().enumerate() { for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i { if signature.args.len() <= i {
return Err(TypeError::new(TypeErrorKind::TooManyArguments{ return Err(TypeError::new(
expected: signature.args.len(), TypeErrorKind::TooManyArguments { expected: signature.args.len(), got: i },
got: i, *loc,
}, *loc)); ));
} }
required.pop(); required.pop();
let (name, expected) = all_names.pop().unwrap(); let (name, expected) = all_names.pop().unwrap();
self.unify_impl(expected, *t, false) self.unify_impl(expected, *t, false).map_err(|_| {
.map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
name, })?;
expected,
got: *t,
}, *loc))?;
} }
for (k, t) in kwargs.iter() { for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) { if let Some(i) = required.iter().position(|v| v == k) {
@ -422,18 +426,17 @@ impl Unifier {
.position(|v| &v.0 == k) .position(|v| &v.0 == k)
.ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; .ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?;
let (name, expected) = all_names.remove(i); let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false) self.unify_impl(expected, *t, false).map_err(|_| {
.map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
name, })?;
expected,
got: *t,
}, *loc))?;
} }
if !required.is_empty() { if !required.is_empty() {
return Err(TypeError::new(TypeErrorKind::MissingArgs(required.iter().join(", ")), *loc)); return Err(TypeError::new(
TypeErrorKind::MissingArgs(required.iter().join(", ")),
*loc,
));
} }
self.unify_impl(*ret, signature.ret, false) self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
.map_err(|mut err| {
if err.loc.is_none() { if err.loc.is_none() {
err.loc = *loc; err.loc = *loc;
} }
@ -471,24 +474,38 @@ impl Unifier {
) )
}; };
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
(TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields2, name: name2, loc: loc2, .. }) => { (
TVar { fields: fields1, id, name: name1, loc: loc1, .. },
TVar { fields: fields2, name: name2, loc: loc2, .. },
) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
(None, Some(fields)) => Some(fields.clone()), (None, Some(fields)) => Some(fields.clone()),
(_, None) => { (_, None) => {
return self.unify_impl(b, a, true); return self.unify_impl(b, a, true);
}, }
(Some(fields1), Some(fields2)) => { (Some(fields1), Some(fields2)) => {
let mut new_fields: Mapping<_, _> = fields2.clone(); let mut new_fields: Mapping<_, _> = fields2.clone();
for (key, val1) in fields1.iter() { for (key, val1) in fields1.iter() {
if let Some(val2) = fields2.get(key) { if let Some(val2) = fields2.get(key) {
self.unify_impl(val1.ty, val2.ty, false) self.unify_impl(val1.ty, val2.ty, false).map_err(|_| {
.map_err(|_| TypeError::new(TypeErrorKind::FieldUnificationError { TypeError::new(
TypeErrorKind::FieldUnificationError {
field: *key, field: *key,
types: (val1.ty, val2.ty), types: (val1.ty, val2.ty),
loc: (*loc1, *loc2), loc: (*loc1, *loc2),
}, None))?; },
new_fields.insert(*key, RecordField::new(val1.ty, val1.mutable || val2.mutable, val1.loc.or(val2.loc))); None,
)
})?;
new_fields.insert(
*key,
RecordField::new(
val1.ty,
val1.mutable || val2.mutable,
val1.loc.or(val2.loc),
),
);
} else { } else {
new_fields.insert(*key, *val1); new_fields.insert(*key, *val1);
} }
@ -496,21 +513,26 @@ impl Unifier {
Some(new_fields) Some(new_fields)
} }
}; };
let intersection = self.get_intersection(a, b).map_err(|_| let intersection = self
TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?.unwrap(); .get_intersection(a, b)
.map_err(|_| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?
.unwrap();
let range = if let TypeEnum::TVar { range, .. } = &*self.get_ty(intersection) { let range = if let TypeEnum::TVar { range, .. } = &*self.get_ty(intersection) {
range.clone() range.clone()
} else { } else {
unreachable!() unreachable!()
}; };
self.unification_table.unify(a, b); self.unification_table.unify(a, b);
self.unification_table.set_value(a, Rc::new(TypeEnum::TVar { self.unification_table.set_value(
a,
Rc::new(TypeEnum::TVar {
id: *id, id: *id,
fields: new_fields, fields: new_fields,
range, range,
name: name1.or(*name2), name: name1.or(*name2),
loc: loc1.or(*loc2) loc: loc1.or(*loc2),
})); }),
);
} }
(TVar { fields: None, range, .. }, _) => { (TVar { fields: None, range, .. }, _) => {
// We check for the range of the type variable to see if unification is allowed. // We check for the range of the type variable to see if unification is allowed.
@ -520,8 +542,12 @@ impl Unifier {
// The return value x of check_var_compatibility would be a new type that is // The return value x of check_var_compatibility would be a new type that is
// guaranteed to be compatible with a under all possible instantiations. So we // guaranteed to be compatible with a under all possible instantiations. So we
// unify x with b to recursively apply the constrains, and then set a to x. // unify x with b to recursively apply the constrains, and then set a to x.
let x = self.check_var_compatibility(b, range).map_err(|_| let x = self
TypeError::new(TypeErrorKind::IncompatibleRange(b, range.clone()), None))?.unwrap_or(b); .check_var_compatibility(b, range)
.map_err(|_| {
TypeError::new(TypeErrorKind::IncompatibleRange(b, range.clone()), None)
})?
.unwrap_or(b);
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
@ -532,17 +558,23 @@ impl Unifier {
RecordKey::Int(i) => { RecordKey::Int(i) => {
if v.mutable { if v.mutable {
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::MutationError(*k, b), v.loc)); TypeErrorKind::MutationError(*k, b),
v.loc,
));
} }
let ind = if i < 0 { len + i } else { i }; let ind = if i < 0 { len + i } else { i };
if ind >= len || ind < 0 { if ind >= len || ind < 0 {
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::TupleIndexOutOfBounds{ index: i, len}, v.loc)); TypeErrorKind::TupleIndexOutOfBounds { index: i, len },
v.loc,
));
} }
self.unify_impl(v.ty, ty[ind as usize], false).map_err(|e| e.at(v.loc))?; self.unify_impl(v.ty, ty[ind as usize], false)
.map_err(|e| e.at(v.loc))?;
}
RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
} }
RecordKey::Str(_) => return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b), v.loc)),
} }
} }
let x = self.check_var_compatibility(b, range)?.unwrap_or(b); let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
@ -552,9 +584,12 @@ impl Unifier {
(TVar { fields: Some(fields), range, .. }, TList { ty }) => { (TVar { fields: Some(fields), range, .. }, TList { ty }) => {
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
RecordKey::Int(_) => self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?, RecordKey::Int(_) => {
RecordKey::Str(_) => return Err(TypeError::new( self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?
TypeErrorKind::NoSuchField(*k, b), v.loc)), }
RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
}
} }
} }
let x = self.check_var_compatibility(b, range)?.unwrap_or(b); let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
@ -578,23 +613,26 @@ impl Unifier {
for (k, field) in map.iter() { for (k, field) in map.iter() {
match *k { match *k {
RecordKey::Str(s) => { RecordKey::Str(s) => {
let (ty, mutable) = fields let (ty, mutable) = fields.get(&s).copied().ok_or_else(|| {
.get(&s) TypeError::new(TypeErrorKind::NoSuchField(*k, b), field.loc)
.copied() })?;
.ok_or_else(|| TypeError::new(
TypeErrorKind::NoSuchField(*k, b), field.loc))?;
// typevar represents the usage of the variable // typevar represents the usage of the variable
// it is OK to have immutable usage for mutable fields // it is OK to have immutable usage for mutable fields
// but cannot have mutable usage for immutable fields // but cannot have mutable usage for immutable fields
if field.mutable && !mutable{ if field.mutable && !mutable {
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::MutationError(*k, b), field.loc)); TypeErrorKind::MutationError(*k, b),
field.loc,
));
} }
self.unify_impl(field.ty, ty, false) self.unify_impl(field.ty, ty, false).map_err(|v| v.at(field.loc))?;
.map_err(|v| v.at(field.loc))?; }
RecordKey::Int(_) => {
return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
field.loc,
))
} }
RecordKey::Int(_) => return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b), field.loc))
} }
} }
let x = self.check_var_compatibility(b, range)?.unwrap_or(b); let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
@ -607,29 +645,35 @@ impl Unifier {
for (k, field) in map.iter() { for (k, field) in map.iter() {
match *k { match *k {
RecordKey::Str(s) => { RecordKey::Str(s) => {
let (ty, _) = fields let (ty, _) = fields.get(&s).copied().ok_or_else(|| {
.get(&s) TypeError::new(TypeErrorKind::NoSuchField(*k, b), field.loc)
.copied() })?;
.ok_or_else(|| TypeError::new(
TypeErrorKind::NoSuchField(*k, b), field.loc))?;
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b), field.loc)) TypeErrorKind::NoSuchField(*k, b),
field.loc,
));
} }
if field.mutable { if field.mutable {
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::MutationError(*k, b), field.loc)); TypeErrorKind::MutationError(*k, b),
field.loc,
));
} }
self.unify_impl(field.ty, ty, false) self.unify_impl(field.ty, ty, false)
.map_err(|v| v.at(field.loc))?; .map_err(|v| v.at(field.loc))?;
} }
RecordKey::Int(_) => return Err(TypeError::new( RecordKey::Int(_) => {
TypeErrorKind::NoSuchField(*k, b), field.loc)) return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
field.loc,
))
}
} }
} }
} else { } else {
// require annotation... // require annotation...
return Err(TypeError::new(TypeErrorKind::RequiresTypeAnn, None)) return Err(TypeError::new(TypeErrorKind::RequiresTypeAnn, None));
} }
let x = self.check_var_compatibility(b, range)?.unwrap_or(b); let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
@ -708,7 +752,11 @@ impl Unifier {
self.stringify_with_notes(ty, &mut None) self.stringify_with_notes(ty, &mut None)
} }
pub fn stringify_with_notes(&self, ty: Type, notes: &mut Option<HashMap<u32, String>>) -> String { pub fn stringify_with_notes(
&self,
ty: Type,
notes: &mut Option<HashMap<u32, String>>,
) -> String {
let top_level = self.top_level.clone(); let top_level = self.top_level.clone();
self.internal_stringify( self.internal_stringify(
ty, ty,
@ -727,51 +775,82 @@ impl Unifier {
) )
}, },
&mut |id| format!("var{}", id), &mut |id| format!("var{}", id),
notes notes,
) )
} }
/// Get string representation of the type /// Get string representation of the type
pub fn internal_stringify<F, G>(&self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G, notes: &mut Option<HashMap<u32, String>>) -> String pub fn internal_stringify<F, G>(
&self,
ty: Type,
obj_to_name: &mut F,
var_to_name: &mut G,
notes: &mut Option<HashMap<u32, String>>,
) -> String
where where
F: FnMut(usize) -> String, F: FnMut(usize) -> String,
G: FnMut(u32) -> String, G: FnMut(u32) -> String,
{ {
let ty = self.unification_table.probe_value_immutable(ty).clone(); let ty = self.unification_table.probe_value_immutable(ty).clone();
match ty.as_ref() { match ty.as_ref() {
TypeEnum::TRigidVar { id, name, .. } => name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), TypeEnum::TRigidVar { id, name, .. } => {
name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id))
}
TypeEnum::TVar { id, name, fields, range, .. } => { TypeEnum::TVar { id, name, fields, range, .. } => {
let n = if let Some(fields) = fields { let n = if let Some(fields) = fields {
let mut fields = fields.iter().map(|(k, f)| format!("{}={}", k, self.internal_stringify(f.ty, obj_to_name, var_to_name, notes))); let mut fields = fields.iter().map(|(k, f)| {
format!(
"{}={}",
k,
self.internal_stringify(f.ty, obj_to_name, var_to_name, notes)
)
});
let fields = fields.join(", "); let fields = fields.join(", ");
format!("{}[{}]", name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), fields) format!(
"{}[{}]",
name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)),
fields
)
} else { } else {
name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)) name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id))
}; };
if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id) { if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id)
{
// just in case if there is any cyclic dependency // just in case if there is any cyclic dependency
notes.as_mut().unwrap().insert(*id, "".into()); notes.as_mut().unwrap().insert(*id, "".into());
let body = format!("{}{{{}}}", n, range.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)).collect::<Vec<_>>().join(", ")); let body = format!(
"{} ∈ {{{}}}",
n,
range
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes))
.collect::<Vec<_>>()
.join(", ")
);
notes.as_mut().unwrap().insert(*id, body); notes.as_mut().unwrap().insert(*id, body);
}; };
n n
} }
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut fields = ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", ")) format!("tuple[{}]", fields.join(", "))
} }
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
} }
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
format!("virtual[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) format!(
"virtual[{}]",
self.internal_stringify(*ty, obj_to_name, var_to_name, notes)
)
} }
TypeEnum::TObj { obj_id, params, .. } => { TypeEnum::TObj { obj_id, params, .. } => {
let name = obj_to_name(obj_id.0); let name = obj_to_name(obj_id.0);
if !params.is_empty() { if !params.is_empty() {
let params = params.iter().map(|(_, v)| { let params = params
self.internal_stringify(*v, obj_to_name, var_to_name, notes) .iter()
}); .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
// sort to preserve order // sort to preserve order
let mut params = params.sorted(); let mut params = params.sorted();
format!("{}[{}]", name, params.join(", ")) format!("{}[{}]", name, params.join(", "))
@ -786,9 +865,18 @@ impl Unifier {
.iter() .iter()
.map(|arg| { .map(|arg| {
if let Some(dv) = &arg.default_value { if let Some(dv) = &arg.default_value {
format!("{}:{}={}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), dv) format!(
"{}:{}={}",
arg.name,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv
)
} else { } else {
format!("{}:{}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)) format!(
"{}:{}",
arg.name,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
)
} }
}) })
.join(", "); .join(", ");
@ -834,7 +922,9 @@ impl Unifier {
} else { } else {
let mapping = vars let mapping = vars
.into_iter() .into_iter()
.map(|(k, range, name, loc)| (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0)) .map(|(k, range, name, loc)| {
(k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0)
})
.collect(); .collect();
self.subst(ty, &mapping).unwrap_or(ty) self.subst(ty, &mapping).unwrap_or(ty)
} }
@ -907,9 +997,8 @@ impl Unifier {
let obj_id = *obj_id; let obj_id = *obj_id;
let params = let params =
self.subst_map(params, mapping, cache).unwrap_or_else(|| params.clone()); self.subst_map(params, mapping, cache).unwrap_or_else(|| params.clone());
let fields = self let fields =
.subst_map2(fields, mapping, cache) self.subst_map2(fields, mapping, cache).unwrap_or_else(|| fields.clone());
.unwrap_or_else(|| fields.clone());
let new_ty = self.add_ty(TypeEnum::TObj { obj_id, params, fields }); let new_ty = self.add_ty(TypeEnum::TObj { obj_id, params, fields });
if let Some(var) = cache.get(&a).unwrap() { if let Some(var) = cache.get(&a).unwrap() {
self.unify_impl(new_ty, *var, false).unwrap(); self.unify_impl(new_ty, *var, false).unwrap();
@ -934,7 +1023,7 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone()); let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret); let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.into_owned(); let args = new_args.into_owned();
Some( self.add_ty(TypeEnum::TFunc( FunSignature { args, ret, vars: params })),) Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params })))
} else { } else {
None None
} }
@ -992,7 +1081,10 @@ impl Unifier {
let x = self.get_ty(a); let x = self.get_ty(a);
let y = self.get_ty(b); let y = self.get_ty(b);
match (x.as_ref(), y.as_ref()) { match (x.as_ref(), y.as_ref()) {
(TVar { range: range1, name, loc, .. }, TVar { fields, range: range2, name: name2, loc: loc2, .. }) => { (
TVar { range: range1, name, loc, .. },
TVar { fields, range: range2, name: name2, loc: loc2, .. },
) => {
// new range is the intersection of them // new range is the intersection of them
// empty range indicates no constraint // empty range indicates no constraint
if range1.is_empty() { if range1.is_empty() {
@ -1000,14 +1092,25 @@ impl Unifier {
} else if range2.is_empty() { } else if range2.is_empty() {
Ok(Some(a)) Ok(Some(a))
} else { } else {
let range = range2.iter().cartesian_product(range1.iter()) let range = range2
.filter_map(|(v1, v2)| self.get_intersection(*v1, *v2).map(|v| v.unwrap_or(*v1)).ok()).collect_vec(); .iter()
.cartesian_product(range1.iter())
.filter_map(|(v1, v2)| {
self.get_intersection(*v1, *v2).map(|v| v.unwrap_or(*v1)).ok()
})
.collect_vec();
if range.is_empty() { if range.is_empty() {
Err(()) Err(())
} else { } else {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let ty = TVar { id, fields: fields.clone(), range, name: name2.or(*name), loc: loc2.or(*loc) }; let ty = TVar {
id,
fields: fields.clone(),
range,
name: name2.or(*name),
loc: loc2.or(*loc),
};
Ok(Some(self.unification_table.new_key(ty.into()))) Ok(Some(self.unification_table.new_key(ty.into())))
} }
} }
@ -1026,13 +1129,15 @@ impl Unifier {
Err(()) Err(())
} }
} }
(TVar { range, .. }, _) => { (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
self.check_var_compatibility(b, range).or(Err(()))
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => {
let ty: Vec<_> = zip(ty1.iter(), ty2.iter()).map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?; let ty: Vec<_> = zip(ty1.iter(), ty2.iter())
.map(|(a, b)| self.get_intersection(*a, *b))
.try_collect()?;
if ty.iter().any(Option::is_some) { if ty.iter().any(Option::is_some) {
Ok(Some(self.add_ty(TTuple { ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect()}))) Ok(Some(self.add_ty(TTuple {
ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
})))
} else { } else {
Ok(None) Ok(None)
} }
@ -1043,9 +1148,7 @@ impl Unifier {
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
} }
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) if id1 == id2 => { (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) if id1 == id2 => Ok(None),
Ok(None)
}
// don't deal with function shape for now // don't deal with function shape for now
_ => Err(()), _ => Err(()),
} }

View File

@ -1,5 +1,5 @@
use super::*;
use super::super::magic_methods::with_fields; use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use std::collections::HashMap; use std::collections::HashMap;
@ -115,10 +115,7 @@ impl TestEnvironment {
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: [("a".into(), (v0, true))] fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
.iter()
.cloned()
.collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(), params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
}), }),
); );
@ -365,9 +362,11 @@ fn test_recursive_subst() {
fn test_virtual() { fn test_virtual() {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new()); let int = env.parse("int", &HashMap::new());
let fun = env.unifier.add_ty(TypeEnum::TFunc( let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
FunSignature { args: vec![], ret: int, vars: HashMap::new() }, args: vec![],
)); ret: int,
vars: HashMap::new(),
}));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
@ -381,15 +380,21 @@ fn test_virtual() {
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect()); let c = env
.unifier
.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap(); env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun)); assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); let d = env
.unifier
.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field does not exist".to_string()));
let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); let d = env
.unifier
.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field does not exist".to_string()));
} }
@ -451,10 +456,7 @@ fn test_typevar_range() {
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
assert_eq!( assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
env.unify(a, int),
Err("Expected any one of these types: 1, but got 0".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
@ -556,9 +558,12 @@ fn test_instantiation() {
let types = types let types = types
.iter() .iter()
.map(|ty| { .map(|ty| {
env.unifier.internal_stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| { env.unifier.internal_stringify(
format!("v{}", i) *ty,
}, &mut None) &mut |i| obj_map.get(&i).unwrap().to_string(),
&mut |i| format!("v{}", i),
&mut None,
)
}) })
.sorted() .sorted()
.collect_vec(); .collect_vec();

View File

@ -1,4 +1,5 @@
mod cslice { // copied from https://github.com/dherman/cslice mod cslice {
// copied from https://github.com/dherman/cslice
use std::marker::PhantomData; use std::marker::PhantomData;
use std::slice; use std::slice;
@ -7,14 +8,12 @@ mod cslice { // copied from https://github.com/dherman/cslice
pub struct CSlice<'a, T> { pub struct CSlice<'a, T> {
base: *const T, base: *const T,
len: usize, len: usize,
marker: PhantomData<&'a ()> marker: PhantomData<&'a ()>,
} }
impl<'a, T> AsRef<[T]> for CSlice<'a, T> { impl<'a, T> AsRef<[T]> for CSlice<'a, T> {
fn as_ref(&self) -> &[T] { fn as_ref(&self) -> &[T] {
unsafe { unsafe { slice::from_raw_parts(self.base, self.len) }
slice::from_raw_parts(self.base, self.len)
}
} }
} }
} }
@ -58,7 +57,6 @@ pub extern "C" fn __artiq_personality(_state: u32, _exception_object: u32, _cont
unimplemented!(); unimplemented!();
} }
extern "C" { extern "C" {
fn run() -> i32; fn run() -> i32;
} }

View File

@ -38,10 +38,8 @@ pub struct Resolver(pub Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, expr: &ast::Expr) -> Option<SymbolValue> { fn get_default_param_value(&self, expr: &ast::Expr) -> Option<SymbolValue> {
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => self.0.module_globals.lock().get(id).cloned(),
self.0.module_globals.lock().get(id).cloned() _ => unimplemented!("other type of expr not supported at {}", expr.location),
}
_ => unimplemented!("other type of expr not supported at {}", expr.location)
} }
} }

View File

@ -1,24 +1,30 @@
use inkwell::{ use inkwell::{
memory_buffer::MemoryBuffer,
passes::{PassManager, PassManagerBuilder}, passes::{PassManager, PassManagerBuilder},
targets::*, targets::*,
OptimizationLevel, memory_buffer::MemoryBuffer, OptimizationLevel,
}; };
use parking_lot::{Mutex, RwLock};
use std::{borrow::Borrow, collections::HashMap, env, fs, path::Path, sync::Arc}; use std::{borrow::Borrow, collections::HashMap, env, fs, path::Path, sync::Arc};
use parking_lot::{RwLock, Mutex};
use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
concrete_type::ConcreteTypeStore, CodeGenTask, DefaultCodeGenerator, WithCall, concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenTask, DefaultCodeGenerator,
WorkerRegistry, irrt::load_irrt, WithCall, WorkerRegistry,
}, },
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
composer::TopLevelComposer, composer::TopLevelComposer, helper::parse_parameter_default_value, type_annotation::*,
TopLevelDef, helper::parse_parameter_default_value, TopLevelDef,
type_annotation::*,
}, },
typecheck::{type_inferencer::PrimitiveStore, typedef::{Type, Unifier, FunSignature}} typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, Unifier},
},
};
use nac3parser::{
ast::{Expr, ExprKind, StmtKind},
parser,
}; };
mod basic_symbol_resolver; mod basic_symbol_resolver;
@ -26,10 +32,7 @@ use basic_symbol_resolver::*;
fn main() { fn main() {
let file_name = env::args().nth(1).unwrap(); let file_name = env::args().nth(1).unwrap();
let threads: u32 = env::args() let threads: u32 = env::args().nth(2).map(|s| str::parse(&s).unwrap()).unwrap_or(1);
.nth(2)
.map(|s| str::parse(&s).unwrap())
.unwrap_or(1);
Target::initialize_all(&InitializationConfig::default()); Target::initialize_all(&InitializationConfig::default());
@ -42,10 +45,8 @@ fn main() {
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
let (mut composer, builtins_def, builtins_ty) = TopLevelComposer::new( let (mut composer, builtins_def, builtins_ty) =
vec![], TopLevelComposer::new(vec![], Default::default());
Default::default()
);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal { let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
@ -83,15 +84,23 @@ fn main() {
x, x,
Default::default(), Default::default(),
)?; )?;
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty) get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0)
} else { } else {
Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) Err(format!(
"expression {:?} cannot be handled as a TypeVar in global scope",
var
))
} }
} else { } else {
Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) Err(format!(
"expression {:?} cannot be handled as a TypeVar in global scope",
var
))
} }
} }
@ -116,7 +125,9 @@ fn main() {
) { ) {
internal_resolver.add_id_type(*id, var); internal_resolver.add_id_type(*id, var);
Ok(()) Ok(())
} else if let Ok(val) = parse_parameter_default_value(value.borrow(), resolver) { } else if let Ok(val) =
parse_parameter_default_value(value.borrow(), resolver)
{
internal_resolver.add_module_global(*id, val); internal_resolver.add_module_global(*id, val);
Ok(()) Ok(())
} else { } else {
@ -126,8 +137,7 @@ fn main() {
)) ))
} }
} }
ExprKind::List { elts, .. } ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
| ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern( handle_assignment_pattern(
elts, elts,
value, value,
@ -135,16 +145,18 @@ fn main() {
internal_resolver, internal_resolver,
def_list, def_list,
unifier, unifier,
primitives primitives,
)?; )?;
Ok(()) Ok(())
} }
_ => Err(format!("assignment to {:?} is not supported at {}", targets[0], targets[0].location)) _ => Err(format!(
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
} }
} else { } else {
match &value.node { match &value.node {
ExprKind::List { elts, .. } ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
| ExprKind::Tuple { elts, .. } => {
if elts.len() != targets.len() { if elts.len() != targets.len() {
Err(format!( Err(format!(
"number of elements to unpack does not match (expect {}, found {}) at {}", "number of elements to unpack does not match (expect {}, found {}) at {}",
@ -161,13 +173,16 @@ fn main() {
internal_resolver, internal_resolver,
def_list, def_list,
unifier, unifier,
primitives primitives,
)?; )?;
} }
Ok(()) Ok(())
} }
}, }
_ => Err(format!("unpack of this expression is not supported at {}", value.location)) _ => Err(format!(
"unpack of this expression is not supported at {}",
value.location
)),
} }
} }
} }
@ -190,9 +205,8 @@ fn main() {
continue; continue;
} }
let (name, def_id, ty) = composer let (name, def_id, ty) =
.register_top_level(stmt, Some(resolver.clone()), "__main__".into()) composer.register_top_level(stmt, Some(resolver.clone()), "__main__".into()).unwrap();
.unwrap();
internal_resolver.add_id_def(name, def_id); internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
@ -200,11 +214,7 @@ fn main() {
} }
} }
let signature = FunSignature { let signature = FunSignature { args: vec![], ret: primitive.int32, vars: HashMap::new() };
args: vec![],
ret: primitive.int32,
vars: HashMap::new(),
};
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
@ -216,17 +226,12 @@ fn main() {
let instance = { let instance = {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut instance = let mut instance = defs[resolver
defs[resolver
.get_identifier_def("run".into()) .get_identifier_def("run".into())
.unwrap_or_else(|_| panic!("cannot find run() entry point")).0 .unwrap_or_else(|_| panic!("cannot find run() entry point"))
].write(); .0]
if let TopLevelDef::Function { .write();
instance_to_stmt, if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance {
instance_to_symbol,
..
} = &mut *instance
{
instance_to_symbol.insert("".to_string(), "run".to_string()); instance_to_symbol.insert("".to_string(), "run".to_string());
instance_to_stmt[""].clone() instance_to_stmt[""].clone()
} else { } else {
@ -291,8 +296,7 @@ fn main() {
passes.run_on(&main); passes.run_on(&main);
let triple = TargetMachine::get_default_triple(); let triple = TargetMachine::get_default_triple();
let target = let target = Target::from_triple(&triple).expect("couldn't create target from target triple");
Target::from_triple(&triple).expect("couldn't create target from target triple");
let target_machine = target let target_machine = target
.create_target_machine( .create_target_machine(
&triple, &triple,
@ -304,10 +308,6 @@ fn main() {
) )
.expect("couldn't create target machine"); .expect("couldn't create target machine");
target_machine target_machine
.write_to_file( .write_to_file(&main, FileType::Object, Path::new("module.o"))
&main,
FileType::Object,
Path::new("module.o"),
)
.expect("couldn't write module to file"); .expect("couldn't write module to file");
} }