core: Apply clippy pedantic changes

This commit is contained in:
David Mak 2023-12-08 17:43:32 +08:00
parent a1f244834f
commit 02933753ca
20 changed files with 644 additions and 655 deletions

View File

@ -374,7 +374,7 @@ impl Nac3 {
}); });
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone(), false) .register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| { .map_err(|e| {
CompileError::new_err(format!( CompileError::new_err(format!(
"compilation failed\n----------\n{}", "compilation failed\n----------\n{}",
@ -596,7 +596,7 @@ impl Nac3 {
threads, threads,
top_level.clone(), top_level.clone(),
&self.llvm_options, &self.llvm_options,
f &f
); );
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);

View File

@ -9,15 +9,11 @@ use std::{
fn main() { fn main() {
const FILE: &str = "src/codegen/irrt/irrt.c"; const FILE: &str = "src/codegen/irrt/irrt.c";
println!("cargo:rerun-if-changed={}", FILE);
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
const FLAG: &[&str] = &[ const FLAG: &[&str] = &[
"--target=wasm32", "--target=wasm32",
FILE, FILE,
@ -29,6 +25,11 @@ fn main() {
"-o", "-o",
"-", "-",
]; ];
println!("cargo:rerun-if-changed={FILE}");
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
let output = Command::new("clang-irrt") let output = Command::new("clang-irrt")
.args(FLAG) .args(FLAG)
.output() .output()
@ -68,5 +69,5 @@ fn main() {
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()) assert!(llvm_as.wait().unwrap().success());
} }

View File

@ -63,6 +63,7 @@ pub enum ConcreteTypeEnum {
} }
impl ConcreteTypeStore { impl ConcreteTypeStore {
#[must_use]
pub fn new() -> ConcreteTypeStore { pub fn new() -> ConcreteTypeStore {
ConcreteTypeStore { ConcreteTypeStore {
store: vec![ store: vec![
@ -80,6 +81,7 @@ impl ConcreteTypeStore {
} }
} }
#[must_use]
pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum { pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum {
&self.store[cty.0] &self.store[cty.0]
} }

View File

@ -47,7 +47,7 @@ pub fn get_subst_key(
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun_vars.iter()); vars.extend(fun_vars.iter());
let sorted = vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted();
sorted sorted
.map(|id| { .map(|id| {
unifier.internal_stringify( unifier.internal_stringify(
@ -119,7 +119,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
SymbolValue::Tuple(ls) => { SymbolValue::Tuple(ls) => {
let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec();
let fields = vals.iter().map(|v| v.get_type()).collect_vec(); let fields = vals.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&fields, false); let ty = self.ctx.struct_type(&fields, false);
let ptr = gen_var(self, ty.into(), Some("tuple")).unwrap(); let ptr = gen_var(self, ty.into(), Some("tuple")).unwrap();
let zero = self.ctx.i32_type().const_zero(); let zero = self.ctx.i32_type().const_zero();
@ -165,7 +165,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
/// See [get_llvm_type]. /// See [`get_llvm_type`].
pub fn get_llvm_type( pub fn get_llvm_type(
&mut self, &mut self,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
@ -183,7 +183,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) )
} }
/// See [get_llvm_abi_type]. /// See [`get_llvm_abi_type`].
pub fn get_llvm_abi_type( pub fn get_llvm_abi_type(
&mut self, &mut self,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
@ -212,7 +212,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
Constant::Bool(v) => { Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.primitives.bool)); assert!(self.unifier.unioned(ty, self.primitives.bool));
let ty = self.ctx.i8_type(); let ty = self.ctx.i8_type();
Some(ty.const_int(if *v { 1 } else { 0 }, false).into()) Some(ty.const_int(u64::from(*v), false).into())
} }
Constant::Int(val) => { Constant::Int(val) => {
let ty = if self.unifier.unioned(ty, self.primitives.int32) let ty = if self.unifier.unioned(ty, self.primitives.int32)
@ -290,12 +290,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
rhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>,
signed: bool signed: bool
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let (lhs, rhs) = let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else {
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) { unreachable!()
(lhs, rhs) };
} else {
unreachable!()
};
let float = self.ctx.f64_type(); let float = self.ctx.f64_type();
match (op, signed) { match (op, signed) {
(Operator::Add, _) => self.builder.build_int_add(lhs, rhs, "add").into(), (Operator::Add, _) => self.builder.build_int_add(lhs, rhs, "add").into(),
@ -318,7 +315,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
(Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").into(), (Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").into(),
// Sign-ness of bitshift operators are always determined by the left operand // Sign-ness of bitshift operators are always determined by the left operand
(Operator::LShift, signed) | (Operator::RShift, signed) => { (Operator::LShift | Operator::RShift, signed) => {
// RHS operand is always 32 bits // RHS operand is always 32 bits
assert_eq!(rhs.get_type().get_bit_width(), 32); assert_eq!(rhs.get_type().get_bit_width(), 32);
@ -365,11 +362,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
lhs: BasicValueEnum<'ctx>, lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
(lhs, rhs)
{
(lhs, rhs)
} else {
unreachable!() unreachable!()
}; };
let float = self.ctx.f64_type(); let float = self.ctx.f64_type();
@ -474,7 +467,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.collect_vec(); .collect_vec();
let result = if let Some(target) = self.unwind_target { let result = if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name)); let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
let result = self let result = self
.builder .builder
.build_invoke(fun, &params, then_block, target, call_name) .build_invoke(fun, &params, then_block, target, call_name)
@ -516,7 +509,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type();
let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into();
let zelf = generator.gen_var_alloc(self, zelf_ty, Some("exn")).unwrap(); let zelf = generator.gen_var_alloc(self, zelf_ty, Some("exn")).unwrap();
self.exception_val.insert(zelf).to_owned() *self.exception_val.insert(zelf)
}; };
let int32 = self.ctx.i32_type(); let int32 = self.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
@ -556,7 +549,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
loc: Location, loc: Location,
) { ) {
let err_msg = self.gen_string(generator, err_msg); let err_msg = self.gen_string(generator, err_msg);
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc) self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
} }
pub fn make_assert_impl( pub fn make_assert_impl(
@ -598,7 +591,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
/// See [CodeGenerator::gen_constructor]. /// See [`CodeGenerator::gen_constructor`].
pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -626,14 +619,14 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
} }
Ok(zelf) Ok(zelf)
} }
_ => unreachable!(), TopLevelDef::Function { .. } => unreachable!(),
} }
} }
/// See [CodeGenerator::gen_func_instance]. /// See [`CodeGenerator::gen_func_instance`].
pub fn gen_func_instance<'ctx>( pub fn gen_func_instance<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, &mut TopLevelDef, String), fun: (&FunSignature, &mut TopLevelDef, String),
id: usize, id: usize,
) -> Result<String, String> { ) -> Result<String, String> {
@ -683,7 +676,7 @@ pub fn gen_func_instance<'ctx>(
args.insert( args.insert(
0, 0,
ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None },
) );
} else { } else {
unreachable!() unreachable!()
} }
@ -707,7 +700,7 @@ pub fn gen_func_instance<'ctx>(
} }
} }
/// See [CodeGenerator::gen_call]. /// See [`CodeGenerator::gen_call`].
pub fn gen_call<'ctx, G: CodeGenerator>( pub fn gen_call<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -738,11 +731,11 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None); let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
let mut keys = fun.0.args.clone(); let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for (key, value) in params.into_iter() { for (key, value) in params {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
} }
// default value handling // default value handling
for k in keys.into_iter() { for k in keys {
if mapping.get(&k.name).is_some() { if mapping.get(&k.name).is_some() {
continue; continue;
} }
@ -774,27 +767,26 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
.map(|(i, v)| (*i, v.get_unique_identifier())) .map(|(i, v)| (*i, v.get_unique_identifier()))
.collect_vec(); .collect_vec();
let mut store = ctx.static_value_store.lock(); let mut store = ctx.static_value_store.lock();
match store.lookup.get(&ids) { if let Some(index) = store.lookup.get(&ids) {
Some(index) => *index, *index
None => { } else {
let length = store.store.len(); let length = store.store.len();
store.lookup.insert(ids, length); store.lookup.insert(ids, length);
store.store.push(static_params.into_iter().collect()); store.store.push(static_params.into_iter().collect());
length length
}
} }
}; };
// special case: extern functions // special case: extern functions
key = if instance_to_stmt.is_empty() { key = if instance_to_stmt.is_empty() {
"".to_string() String::new()
} else { } else {
format!("{}:{}", id, old_key) format!("{id}:{old_key}")
}; };
param_vals = real_params param_vals = real_params
.into_iter() .into_iter()
.map(|(p, t)| p.to_basic_value_enum(ctx, generator, t)) .map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
.collect::<Result<Vec<_>, String>>()?; .collect::<Result<Vec<_>, String>>()?;
instance_to_symbol.get(&key).cloned().ok_or_else(|| "".into()) instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
} }
TopLevelDef::Class { .. } => { TopLevelDef::Class { .. } => {
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
@ -900,7 +892,7 @@ pub fn destructure_range<'ctx>(
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting /// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. /// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
/// ///
/// Returns an instance of [PointerValue] pointing to the List structure. The List structure is /// Returns an instance of [`PointerValue`] pointing to the List structure. The List structure is
/// defined as `type { ty*, size_t }` in LLVM, where the first element stores the pointer to the /// defined as `type { ty*, size_t }` in LLVM, where the first element stores the pointer to the
/// data, and the second element stores the size of the List. /// data, and the second element stores the size of the List.
pub fn allocate_list<'ctx, G: CodeGenerator>( pub fn allocate_list<'ctx, G: CodeGenerator>(
@ -1083,7 +1075,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index"));
}; };
for cond in ifs.iter() { for cond in ifs {
let result = if let Some(v) = generator.gen_expr(ctx, cond)? { let result = if let Some(v) = generator.gen_expr(ctx, cond)? {
v.to_basic_value_enum(ctx, generator, cond.custom.unwrap())?.into_int_value() v.to_basic_value_enum(ctx, generator, cond.custom.unwrap())?.into_int_value()
} else { } else {
@ -1226,11 +1218,11 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
Some((left.custom.unwrap(), left_val.into())), Some((left.custom.unwrap(), left_val.into())),
(&signature, fun_id), (&signature, fun_id),
vec![(None, right_val.into())], vec![(None, right_val.into())],
).map(|f| f.map(|f| f.into())) ).map(|f| f.map(Into::into))
} }
} }
/// See [CodeGenerator::gen_expr]. /// See [`CodeGenerator::gen_expr`].
pub fn gen_expr<'ctx, G: CodeGenerator>( pub fn gen_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1462,7 +1454,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").into(), ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").into(),
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").into(), ast::Unaryop::Invert => ctx.builder.build_not(val, "not").into(),
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").into(), ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").into(),
_ => val.into(), ast::Unaryop::UAdd => val.into(),
} }
} else if ty == ctx.primitives.float { } else if ty == ctx.primitives.float {
let val = val.into_float_value(); let val = val.into_float_value();
@ -1574,11 +1566,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let test = generator.bool_to_i1(ctx, test); let test = generator.bool_to_i1(ctx, test);
let body_ty = body.custom.unwrap(); let body_ty = body.custom.unwrap();
let is_none = ctx.unifier.get_representative(body_ty) == ctx.primitives.none; let is_none = ctx.unifier.get_representative(body_ty) == ctx.primitives.none;
let result = if !is_none { let result = if is_none {
None
} else {
let llvm_ty = ctx.get_llvm_type(generator, body_ty); let llvm_ty = ctx.get_llvm_type(generator, body_ty);
Some(ctx.builder.build_alloca(llvm_ty, "if_exp_result")) Some(ctx.builder.build_alloca(llvm_ty, "if_exp_result"))
} else {
None
}; };
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = ctx.ctx.append_basic_block(current, "then"); let then_bb = ctx.ctx.append_basic_block(current, "then");
@ -1640,15 +1632,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let kw_iter = kw_iter.collect::<Result<Vec<_>, _>>()?; let kw_iter = kw_iter.collect::<Result<Vec<_>, _>>()?;
params.extend(kw_iter); params.extend(kw_iter);
let call = ctx.calls.get(&expr.location.into()); let call = ctx.calls.get(&expr.location.into());
let signature = match call { let signature = if let Some(call) = call {
Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), ctx.unifier.get_call_signature(*call).unwrap()
None => { } else {
let ty = func.custom.unwrap(); let ty = func.custom.unwrap();
if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) { if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) {
sign.clone() sign.clone()
} else { } else {
unreachable!() unreachable!()
}
} }
}; };
let func = func.as_ref(); let func = func.as_ref();
@ -1661,12 +1652,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.map_err(|e| format!("{} (at {})", e, func.location))?; .map_err(|e| format!("{} (at {})", e, func.location))?;
return Ok(generator return Ok(generator
.gen_call(ctx, None, (&signature, fun), params)? .gen_call(ctx, None, (&signature, fun), params)?
.map(|v| v.into())); .map(Into::into));
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let val = match generator.gen_expr(ctx, value)? { let Some(val) = generator.gen_expr(ctx, value)? else {
Some(v) => v, return Ok(None)
None => return Ok(None),
}; };
let id = if let TypeEnum::TObj { obj_id, .. } = let id = if let TypeEnum::TObj { obj_id, .. } =
@ -1745,7 +1735,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
"unwrap_some_load" "unwrap_some_load"
).into())) ).into()))
} }
_ => unreachable!("option must be static or ptr") ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr")
} }
} }
@ -1759,7 +1749,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
(&signature, fun_id), (&signature, fun_id),
params, params,
)? )?
.map(|v| v.into())); .map(Into::into));
} }
_ => unimplemented!(), _ => unimplemented!(),
} }
@ -1871,14 +1861,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ctx.builder.build_extract_value(v, index, "tup_elem").unwrap().into() ctx.builder.build_extract_value(v, index, "tup_elem").unwrap().into()
} }
Some(ValueEnum::Static(v)) => { Some(ValueEnum::Static(v)) => {
match v.get_tuple_element(index) { if let Some(v) = v.get_tuple_element(index) {
Some(v) => v, v
None => { } else {
let tup = v let tup = v
.to_basic_value_enum(ctx, generator, value.custom.unwrap())? .to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_struct_value(); .into_struct_value();
ctx.builder.build_extract_value(tup, index, "tup_elem").unwrap().into() ctx.builder.build_extract_value(tup, index, "tup_elem").unwrap().into()
}
} }
} }
None => return Ok(None), None => return Ok(None),

View File

@ -66,7 +66,7 @@ pub trait CodeGenerator {
fun: (&FunSignature, &mut TopLevelDef, String), fun: (&FunSignature, &mut TopLevelDef, String),
id: usize, id: usize,
) -> Result<String, String> { ) -> Result<String, String> {
gen_func_instance(ctx, obj, fun, id) gen_func_instance(ctx, &obj, fun, id)
} }
/// Generate the code for an expression. /// Generate the code for an expression.
@ -194,7 +194,7 @@ pub trait CodeGenerator {
gen_block(self, ctx, stmts) gen_block(self, ctx, stmts)
} }
/// See [bool_to_i1]. /// See [`bool_to_i1`].
fn bool_to_i1<'ctx>( fn bool_to_i1<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -203,7 +203,7 @@ pub trait CodeGenerator {
bool_to_i1(&ctx.builder, bool_value) bool_to_i1(&ctx.builder, bool_value)
} }
/// See [bool_to_i8]. /// See [`bool_to_i8`].
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -219,6 +219,7 @@ pub struct DefaultCodeGenerator {
} }
impl DefaultCodeGenerator { impl DefaultCodeGenerator {
#[must_use]
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator { pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(matches!(size_t, 32 | 64)); assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t } DefaultCodeGenerator { name, size_t }
@ -227,7 +228,7 @@ impl DefaultCodeGenerator {
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [CodeGenerator]. /// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name
} }

View File

@ -12,6 +12,7 @@ use inkwell::{
}; };
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
@ -288,7 +289,7 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
} }
/// This function handles 'end' **inclusively**. /// This function handles 'end' **inclusively**.
/// Order of tuples assign_idx and value_idx is ('start', 'end', 'step'). /// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function /// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx>( pub fn list_slice_assignment<'ctx>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,

View File

@ -81,8 +81,9 @@ pub struct CodeGenTargetMachineOptions {
impl CodeGenTargetMachineOptions { impl CodeGenTargetMachineOptions {
/// Creates an instance of [CodeGenTargetMachineOptions] using the triple of the host machine. /// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults. /// Other options are set to defaults.
#[must_use]
pub fn from_host_triple() -> CodeGenTargetMachineOptions { pub fn from_host_triple() -> CodeGenTargetMachineOptions {
CodeGenTargetMachineOptions { CodeGenTargetMachineOptions {
triple: TargetMachine::get_default_triple().as_str().to_string_lossy().into_owned(), triple: TargetMachine::get_default_triple().as_str().to_string_lossy().into_owned(),
@ -93,8 +94,9 @@ impl CodeGenTargetMachineOptions {
} }
} }
/// Creates an instance of [CodeGenTargetMachineOptions] using the properties of the host /// Creates an instance of [`CodeGenTargetMachineOptions`] using the properties of the host
/// machine. Other options are set to defaults. /// machine. Other options are set to defaults.
#[must_use]
pub fn from_host() -> CodeGenTargetMachineOptions { pub fn from_host() -> CodeGenTargetMachineOptions {
CodeGenTargetMachineOptions { CodeGenTargetMachineOptions {
cpu: TargetMachine::get_host_cpu_name().to_string(), cpu: TargetMachine::get_host_cpu_name().to_string(),
@ -103,9 +105,10 @@ impl CodeGenTargetMachineOptions {
} }
} }
/// Creates a [TargetMachine] using the target options specified by this struct. /// Creates a [`TargetMachine`] using the target options specified by this struct.
/// ///
/// See [Target::create_target_machine]. /// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine( pub fn create_target_machine(
&self, &self,
level: OptimizationLevel, level: OptimizationLevel,
@ -195,7 +198,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator]. /// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool { pub fn is_terminated(&self) -> bool {
self.builder.get_insert_block().and_then(|bb| bb.get_terminator()).is_some() self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some()
} }
} }
@ -206,12 +209,13 @@ pub struct WithCall {
} }
impl WithCall { impl WithCall {
#[must_use]
pub fn new(fp: Fp) -> WithCall { pub fn new(fp: Fp) -> WithCall {
WithCall { fp } WithCall { fp }
} }
pub fn run(&self, m: &Module) { pub fn run(&self, m: &Module) {
(self.fp)(m) (self.fp)(m);
} }
} }
@ -238,20 +242,21 @@ pub struct WorkerRegistry {
impl WorkerRegistry { impl WorkerRegistry {
/// Creates workers for this registry. /// Creates workers for this registry.
#[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>( pub fn create_workers<G: CodeGenerator + Send + 'static>(
generators: Vec<Box<G>>, generators: Vec<Box<G>>,
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
llvm_options: &CodeGenLLVMOptions, llvm_options: &CodeGenLLVMOptions,
f: Arc<WithCall>, f: &Arc<WithCall>,
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) { ) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded(); let (sender, receiver) = unbounded();
let task_count = Mutex::new(0); let task_count = Mutex::new(0);
let wait_condvar = Condvar::new(); let wait_condvar = Condvar::new();
// init: 0 to be empty // init: 0 to be empty
let mut static_value_store: StaticValueStore = Default::default(); let mut static_value_store = StaticValueStore::default();
static_value_store.lookup.insert(Default::default(), 0); static_value_store.lookup.insert(Vec::default(), 0);
static_value_store.store.push(Default::default()); static_value_store.store.push(HashMap::default());
let registry = Arc::new(WorkerRegistry { let registry = Arc::new(WorkerRegistry {
sender: Arc::new(sender), sender: Arc::new(sender),
@ -266,19 +271,19 @@ impl WorkerRegistry {
}); });
let mut handles = Vec::new(); let mut handles = Vec::new();
for mut generator in generators.into_iter() { for mut generator in generators {
let registry = registry.clone(); let registry = registry.clone();
let registry2 = registry.clone(); let registry2 = registry.clone();
let f = f.clone(); let f = f.clone();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
registry.worker_thread(generator.as_mut(), f); registry.worker_thread(generator.as_mut(), &f);
}); });
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
if let Err(e) = handle.join() { if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() { if let Some(e) = e.downcast_ref::<&'static str>() {
eprintln!("Got an error: {}", e); eprintln!("Got an error: {e}");
} else { } else {
eprintln!("Got an unknown error: {:?}", e); eprintln!("Got an unknown error: {e:?}");
} }
registry2.panicked.store(true, Ordering::SeqCst); registry2.panicked.store(true, Ordering::SeqCst);
registry2.wait_condvar.notify_all(); registry2.wait_condvar.notify_all();
@ -314,19 +319,17 @@ impl WorkerRegistry {
for handle in handles { for handle in handles {
handle.join().unwrap(); handle.join().unwrap();
} }
if self.panicked.load(Ordering::SeqCst) { assert!(!self.panicked.load(Ordering::SeqCst), "tasks panicked");
panic!("tasks panicked");
}
} }
/// Adds a task to this [WorkerRegistry]. /// Adds a task to this [`WorkerRegistry`].
pub fn add_task(&self, task: CodeGenTask) { pub fn add_task(&self, task: CodeGenTask) {
*self.task_count.lock() += 1; *self.task_count.lock() += 1;
self.sender.send(Some(task)).unwrap(); self.sender.send(Some(task)).unwrap();
} }
/// Function executed by worker thread for generating IR for each function. /// Function executed by worker thread for generating IR for each function.
fn worker_thread<G: CodeGenerator>(&self, generator: &mut G, f: Arc<WithCall>) { fn worker_thread<G: CodeGenerator>(&self, generator: &mut G, f: &Arc<WithCall>) {
let context = Context::create(); let context = Context::create();
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
@ -359,9 +362,7 @@ impl WorkerRegistry {
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
self.wait_condvar.notify_all(); self.wait_condvar.notify_all();
} }
if !errors.is_empty() { assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n"));
panic!("Codegen error: {}", errors.into_iter().sorted().join("\n----------\n"));
}
let result = module.verify(); let result = module.verify();
if let Err(err) = result { if let Err(err) = result {
@ -419,7 +420,7 @@ fn get_llvm_type<'ctx>(
use TypeEnum::*; use TypeEnum::*;
// we assume the type cache should already contain primitive types, // we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer. // and they should be passed by value instead of passing as pointer.
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| { type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
@ -454,32 +455,31 @@ fn get_llvm_type<'ctx>(
&*definition.read() &*definition.read()
{ {
let name = unifier.stringify(ty); let name = unifier.stringify(ty);
match module.get_struct_type(&name) { if let Some(t) = module.get_struct_type(&name) {
Some(t) => t.ptr_type(AddressSpace::default()).into(), t.ptr_type(AddressSpace::default()).into()
None => { } else {
let struct_type = ctx.opaque_struct_type(&name); let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert( type_cache.insert(
unifier.get_representative(ty), unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into()
);
let fields = fields_list
.iter()
.map(|f| {
get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
primitives,
fields[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into()
} );
let fields = fields_list
.iter()
.map(|f| {
get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
primitives,
fields[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
} }
} else { } else {
unreachable!() unreachable!()
@ -517,12 +517,12 @@ fn get_llvm_type<'ctx>(
}) })
} }
/// Retrieves the [LLVM type][BasicTypeEnum] corresponding to the [Type]. /// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`].
/// ///
/// This function is used mainly to obtain the ABI representation of `ty`, e.g. a `bool` is /// This function is used mainly to obtain the ABI representation of `ty`, e.g. a `bool` is
/// would be represented by an `i1`. /// would be represented by an `i1`.
/// ///
/// The difference between the in-memory representation (as returned by [get_llvm_type]) and the /// The difference between the in-memory representation (as returned by [`get_llvm_type`]) and the
/// ABI representation is that the in-memory representation must be at least byte-sized and must /// ABI representation is that the in-memory representation must be at least byte-sized and must
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such /// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations. /// restriction for ABI representations.
@ -551,7 +551,7 @@ fn get_llvm_abi_type<'ctx>(
/// the target processor) by value, a synthetic parameter with a pointer type will be passed in the /// the target processor) by value, a synthetic parameter with a pointer type will be passed in the
/// slot of the first parameter to act as the location of which the return value is passed into. /// slot of the first parameter to act as the location of which the return value is passed into.
/// ///
/// See [https://releases.llvm.org/14.0.0/docs/LangRef.html#parameter-attributes] for more /// See <https://releases.llvm.org/14.0.0/docs/LangRef.html#parameter-attributes> for more
/// information. /// information.
fn need_sret(ty: BasicTypeEnum) -> bool { fn need_sret(ty: BasicTypeEnum) -> bool {
fn need_sret_impl(ty: BasicTypeEnum, maybe_large: bool) -> bool { fn need_sret_impl(ty: BasicTypeEnum, maybe_large: bool) -> bool {
@ -585,7 +585,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
unifier.top_level = Some(top_level_ctx.clone()); unifier.top_level = Some(top_level_ctx.clone());
let mut cache = HashMap::new(); let mut cache = HashMap::new();
for (a, b) in task.subst.iter() { for (a, b) in &task.subst {
// this should be unification between variables and concrete types // this should be unification between variables and concrete types
// and should not cause any problem... // and should not cause any problem...
let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache); let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache);
@ -599,7 +599,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
Err(err) Err(err)
} }
}) })
.unwrap() .unwrap();
} }
// rebuild primitive store with unique representatives // rebuild primitive store with unique representatives
@ -642,22 +642,21 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
(primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::default()).into()), (primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::default()).into()),
(primitives.exception, { (primitives.exception, {
let name = "Exception"; let name = "Exception";
match module.get_struct_type(name) { if let Some(t) = module.get_struct_type(name) {
Some(t) => t.ptr_type(AddressSpace::default()).as_basic_type_enum(), t.ptr_type(AddressSpace::default()).as_basic_type_enum()
None => { } else {
let exception = context.opaque_struct_type("Exception"); let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into(); let int32 = context.i32_type().into();
let int64 = context.i64_type().into(); let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum(); let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64]; let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false); exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum() exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
} }
}) })
] ]
.iter() .iter()
.cloned() .copied()
.collect(); .collect();
// NOTE: special handling of option cannot use this type cache since it contains type var, // NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead // handled inside get_llvm_type instead
@ -733,7 +732,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let body_bb = context.append_basic_block(fn_val, "body"); let body_bb = context.append_basic_block(fn_val, "body");
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = if has_sret { 1 } else { 0 }; let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate() { for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(
@ -779,7 +778,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let store = registry.static_value_store.lock(); let store = registry.static_value_store.lock();
store.store[task.id].clone() store.store[task.id].clone()
}; };
for (k, v) in static_values.into_iter() { for (k, v) in static_values {
let (_, static_val, _) = var_assignment.get_mut(&args[k].name).unwrap(); let (_, static_val, _) = var_assignment.get_mut(&args[k].name).unwrap();
*static_val = Some(v); *static_val = Some(v);
} }
@ -849,19 +848,19 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
return_buffer, return_buffer,
unwind_target: None, unwind_target: None,
outer_catch_clauses: None, outer_catch_clauses: None,
const_strings: Default::default(), const_strings: HashMap::default(),
registry, registry,
var_assignment, var_assignment,
type_cache, type_cache,
primitives, primitives,
init_bb, init_bb,
exception_val: Default::default(), exception_val: Option::default(),
builder, builder,
module, module,
unifier, unifier,
static_value_store, static_value_store,
need_sret: has_sret, need_sret: has_sret,
current_loc: Default::default(), current_loc: Location::default(),
debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()),
}; };
@ -894,12 +893,12 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
/// Generates LLVM IR for a function. /// Generates LLVM IR for a function.
/// ///
/// * `context` - The [LLVM Context][Context] used in generating the function body. /// * `context` - The [LLVM Context][`Context`] used in generating the function body.
/// * `generator` - The [CodeGenerator] for generating various program constructs. /// * `generator` - The [`CodeGenerator`] for generating various program constructs.
/// * `registry` - The [WorkerRegistry] responsible for monitoring this function generation task. /// * `registry` - The [`WorkerRegistry`] responsible for monitoring this function generation task.
/// * `builder` - The [Builder] used for generating LLVM IR. /// * `builder` - The [`Builder`] used for generating LLVM IR.
/// * `module` - The [Module] of which the generated LLVM function will be inserted into. /// * `module` - The [`Module`] of which the generated LLVM function will be inserted into.
/// * `task` - The [CodeGenTask] associated with this function generation task. /// * `task` - The [`CodeGenTask`] associated with this function generation task.
/// ///
pub fn gen_func<'ctx, G: CodeGenerator>( pub fn gen_func<'ctx, G: CodeGenerator>(
context: &'ctx Context, context: &'ctx Context,
@ -917,15 +916,15 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
/// Converts the value of a boolean-like value `bool_value` into an `i1`. /// Converts the value of a boolean-like value `bool_value` into an `i1`.
fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntValue<'ctx> { fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntValue<'ctx> {
if bool_value.get_type().get_bit_width() != 1 { if bool_value.get_type().get_bit_width() == 1 {
bool_value
} else {
builder.build_int_compare( builder.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
bool_value, bool_value,
bool_value.get_type().const_zero(), bool_value.get_type().const_zero(),
"tobool" "tobool"
) )
} else {
bool_value
} }
} }
@ -966,7 +965,7 @@ fn bool_to_i8<'ctx>(
/// let cmp = lo < hi; /// let cmp = lo < hi;
/// ``` /// ```
/// ///
/// Returns an `i1` [IntValue] representing the result of whether the `value` is in the range. /// Returns an `i1` [`IntValue`] representing the result of whether the `value` is in the range.
fn gen_in_range_check<'ctx>( fn gen_in_range_check<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>, value: IntValue<'ctx>,

View File

@ -24,7 +24,7 @@ use nac3parser::ast::{
}; };
use std::convert::TryFrom; use std::convert::TryFrom;
/// See [CodeGenerator::gen_var_alloc]. /// See [`CodeGenerator::gen_var_alloc`].
pub fn gen_var<'ctx>( pub fn gen_var<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
@ -54,7 +54,7 @@ pub fn gen_var<'ctx>(
Ok(ptr) Ok(ptr)
} }
/// See [CodeGenerator::gen_store_target]. /// See [`CodeGenerator::gen_store_target`].
pub fn gen_store_target<'ctx, G: CodeGenerator>( pub fn gen_store_target<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -84,9 +84,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} else { } else {
return Ok(None) return Ok(None)
}; };
let ptr = if let BasicValueEnum::PointerValue(v) = val { let BasicValueEnum::PointerValue(ptr) = val else {
v
} else {
unreachable!(); unreachable!();
}; };
unsafe { unsafe {
@ -164,7 +162,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
})) }))
} }
/// See [CodeGenerator::gen_assign]. /// See [`CodeGenerator::gen_assign`].
pub fn gen_assign<'ctx, G: CodeGenerator>( pub fn gen_assign<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -212,14 +210,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {
return Ok(()) return Ok(())
}; };
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind) list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
} else { } else {
unreachable!() unreachable!()
} }
} }
_ => { _ => {
let name = if let ExprKind::Name { id, .. } = &target.node { let name = if let ExprKind::Name { id, .. } = &target.node {
format!("{}.addr", id) format!("{id}.addr")
} else { } else {
String::from("target.addr") String::from("target.addr")
}; };
@ -241,7 +239,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [CodeGenerator::gen_for]. /// See [`CodeGenerator::gen_for`].
pub fn gen_for<G: CodeGenerator>( pub fn gen_for<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -255,7 +253,7 @@ pub fn gen_for<G: CodeGenerator>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let zero = int32.const_zero(); let zero = int32.const_zero();
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let body_bb = ctx.ctx.append_basic_block(current, "for.body");
let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
// if there is no orelse, we just go to cont_bb // if there is no orelse, we just go to cont_bb
@ -368,7 +366,7 @@ pub fn gen_for<G: CodeGenerator>(
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -387,7 +385,7 @@ pub fn gen_for<G: CodeGenerator>(
} }
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -402,7 +400,7 @@ pub fn gen_for<G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [CodeGenerator::gen_while]. /// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>( pub fn gen_while<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -441,7 +439,7 @@ pub fn gen_while<G: CodeGenerator>(
}; };
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -457,7 +455,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.build_unconditional_branch(cont_bb);
} }
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -471,7 +469,7 @@ pub fn gen_while<G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [CodeGenerator::gen_if]. /// See [`CodeGenerator::gen_if`].
pub fn gen_if<G: CodeGenerator>( pub fn gen_if<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -503,7 +501,7 @@ pub fn gen_if<G: CodeGenerator>(
}; };
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -529,7 +527,7 @@ pub fn gen_if<G: CodeGenerator>(
if let Some(cont_bb) = cont_bb { if let Some(cont_bb) = cont_bb {
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
@ -571,8 +569,7 @@ pub fn get_builtins<'ctx>(
.ctx .ctx
.void_type() .void_type()
.fn_type(&[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], false), .fn_type(&[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], false),
"__nac3_resume" => ctx.ctx.void_type().fn_type(&[], false), "__nac3_resume" | "__nac3_end_catch" => ctx.ctx.void_type().fn_type(&[], false),
"__nac3_end_catch" => ctx.ctx.void_type().fn_type(&[], false),
_ => unimplemented!(), _ => unimplemented!(),
}; };
let fun = ctx.module.add_function(symbol, ty, None); let fun = ctx.module.add_function(symbol, ty, None);
@ -613,20 +610,20 @@ pub fn exn_constructor<'ctx>(
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id");
let id = ctx.resolver.get_string_id(&exception_name); let id = ctx.resolver.get_string_id(&exception_name);
ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)); ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false));
let empty_string = ctx.gen_const(generator, &Constant::Str("".into()), ctx.primitives.str); let empty_string = ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str);
let ptr = let ptr =
ctx.builder.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg"); ctx.builder.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg");
let msg = if !args.is_empty() { let msg = if args.is_empty() {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.str)?
} else {
empty_string.unwrap() empty_string.unwrap()
} else {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.str)?
}; };
ctx.builder.build_store(ptr, msg); ctx.builder.build_store(ptr, msg);
for i in [6, 7, 8].iter() { for i in &[6, 7, 8] {
let value = if !args.is_empty() { let value = if args.is_empty() {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)?
} else {
ctx.ctx.i64_type().const_zero().into() ctx.ctx.i64_type().const_zero().into()
} else {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)?
}; };
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, zelf,
@ -636,7 +633,7 @@ pub fn exn_constructor<'ctx>(
ctx.builder.build_store(ptr, value); ctx.builder.build_store(ptr, value);
} }
// set file, func to empty string // set file, func to empty string
for i in [1, 4].iter() { for i in &[1, 4] {
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, zelf,
&[zero, int32.const_int(*i, false)], &[zero, int32.const_int(*i, false)],
@ -645,7 +642,7 @@ pub fn exn_constructor<'ctx>(
ctx.builder.build_store(ptr, empty_string.unwrap()); ctx.builder.build_store(ptr, empty_string.unwrap());
} }
// set ints to zero // set ints to zero
for i in [2, 3].iter() { for i in &[2, 3] {
let ptr = ctx.builder.build_in_bounds_gep( let ptr = ctx.builder.build_in_bounds_gep(
zelf, zelf,
&[zero, int32.const_int(*i, false)], &[zero, int32.const_int(*i, false)],
@ -768,7 +765,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
let mut clauses = Vec::new(); let mut clauses = Vec::new();
let mut found_catch_all = false; let mut found_catch_all = false;
for handler_node in handlers.iter() { for handler_node in handlers {
let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node;
// none or Exception // none or Exception
if type_.is_none() if type_.is_none()
@ -779,30 +776,30 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
clauses.push(None); clauses.push(None);
found_catch_all = true; found_catch_all = true;
break; break;
} else {
let type_ = type_.as_ref().unwrap();
let exn_name = ctx.resolver.get_type_name(
&ctx.top_level.definitions.read(),
&mut ctx.unifier,
type_.custom.unwrap(),
);
let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id
} else {
unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name);
let exn_id_global =
ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{}", exn_id));
exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false));
clauses.push(Some(exn_id_global.as_pointer_value().as_basic_value_enum()));
} }
let type_ = type_.as_ref().unwrap();
let exn_name = ctx.resolver.get_type_name(
&ctx.top_level.definitions.read(),
&mut ctx.unifier,
type_.custom.unwrap(),
);
let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id
} else {
unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name);
let exn_id_global =
ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{exn_id}"));
exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false));
clauses.push(Some(exn_id_global.as_pointer_value().as_basic_value_enum()));
} }
let mut all_clauses = clauses.clone(); let mut all_clauses = clauses.clone();
if let Some(old_clauses) = &ctx.outer_catch_clauses { if let Some(old_clauses) = &ctx.outer_catch_clauses {
if !found_catch_all { if !found_catch_all {
all_clauses.extend_from_slice(&old_clauses.0) all_clauses.extend_from_slice(&old_clauses.0);
} }
} }
let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn)); let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn));
@ -819,7 +816,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx.return_target = old_return; ctx.return_target = old_return;
ctx.loop_target = old_loop_target.or(ctx.loop_target).take(); ctx.loop_target = old_loop_target.or(ctx.loop_target).take();
let old_unwind = if !finalbody.is_empty() { let old_unwind = if finalbody.is_empty() {
None
} else {
let final_landingpad = ctx.ctx.append_basic_block(current_fun, "try.catch.final"); let final_landingpad = ctx.ctx.append_basic_block(current_fun, "try.catch.final");
ctx.builder.position_at_end(final_landingpad); ctx.builder.position_at_end(final_landingpad);
ctx.builder.build_landing_pad( ctx.builder.build_landing_pad(
@ -832,8 +831,6 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx.builder.build_unconditional_branch(cleanup.unwrap()); ctx.builder.build_unconditional_branch(cleanup.unwrap());
ctx.builder.position_at_end(body); ctx.builder.position_at_end(body);
ctx.unwind_target.replace(final_landingpad) ctx.unwind_target.replace(final_landingpad)
} else {
None
}; };
// run end_catch before continue/break/return // run end_catch before continue/break/return
@ -886,7 +883,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
let mut post_handlers = Vec::new(); let mut post_handlers = Vec::new();
let exnid = if !handlers.is_empty() { let exnid = if handlers.is_empty() {
None
} else {
ctx.builder.position_at_end(dispatcher); ctx.builder.position_at_end(dispatcher);
unsafe { unsafe {
let zero = ctx.ctx.i32_type().const_zero(); let zero = ctx.ctx.i32_type().const_zero();
@ -897,8 +896,6 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
); );
Some(ctx.builder.build_load(exnid_ptr, "exnid")) Some(ctx.builder.build_load(exnid_ptr, "exnid"))
} }
} else {
None
}; };
for (handler_node, exn_type) in handlers.iter().zip(clauses.iter()) { for (handler_node, exn_type) in handlers.iter().zip(clauses.iter()) {
@ -1011,7 +1008,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
let dest = ctx.builder.build_load(final_state, "final_dest"); let dest = ctx.builder.build_load(final_state, "final_dest");
ctx.builder.build_indirect_branch(dest, &final_targets); ctx.builder.build_indirect_branch(dest, &final_targets);
} }
for block in final_paths.iter() { for block in &final_paths {
if block.get_terminator().is_none() { if block.get_terminator().is_none() {
ctx.builder.position_at_end(*block); ctx.builder.position_at_end(*block);
ctx.builder.build_unconditional_branch(finalizer); ctx.builder.build_unconditional_branch(finalizer);
@ -1034,7 +1031,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
} }
} }
/// See [CodeGenerator::gen_with]. /// See [`CodeGenerator::gen_with`].
pub fn gen_with<G: CodeGenerator>( pub fn gen_with<G: CodeGenerator>(
_: &mut G, _: &mut G,
_: &mut CodeGenContext<'_, '_>, _: &mut CodeGenContext<'_, '_>,
@ -1050,7 +1047,7 @@ pub fn gen_return<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
value: &Option<Box<Expr<Option<Type>>>>, value: &Option<Box<Expr<Option<Type>>>>,
) -> Result<(), String> { ) -> Result<(), String> {
let func = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap(); let func = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let value = if let Some(v_expr) = value.as_ref() { let value = if let Some(v_expr) = value.as_ref() {
if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() { if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() {
Some( Some(
@ -1096,7 +1093,7 @@ pub fn gen_return<G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [CodeGenerator::gen_stmt]. /// See [`CodeGenerator::gen_stmt`].
pub fn gen_stmt<G: CodeGenerator>( pub fn gen_stmt<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -1133,7 +1130,7 @@ pub fn gen_stmt<G: CodeGenerator>(
let Some(value) = generator.gen_expr(ctx, value)? else { let Some(value) = generator.gen_expr(ctx, value)? else {
return Ok(()) return Ok(())
}; };
for target in targets.iter() { for target in targets {
generator.gen_assign(ctx, target, value.clone())?; generator.gen_assign(ctx, target, value.clone())?;
} }
} }
@ -1185,7 +1182,7 @@ pub fn gen_stmt<G: CodeGenerator>(
err_msg, err_msg,
[None, None, None], [None, None, None],
stmt.location, stmt.location,
) );
} }
_ => unimplemented!() _ => unimplemented!()
}; };

View File

@ -35,10 +35,10 @@ pub enum SymbolValue {
} }
impl SymbolValue { impl SymbolValue {
/// Creates a [SymbolValue] from a [Constant]. /// Creates a [`SymbolValue`] from a [`Constant`].
/// ///
/// * `constant` - The constant to create the value from. /// * `constant` - The constant to create the value from.
/// * `expected_ty` - The expected type of the [SymbolValue]. /// * `expected_ty` - The expected type of the [`SymbolValue`].
pub fn from_constant( pub fn from_constant(
constant: &Constant, constant: &Constant,
expected_ty: Type, expected_ty: Type,
@ -50,21 +50,21 @@ impl SymbolValue {
if unifier.unioned(expected_ty, primitives.option) { if unifier.unioned(expected_ty, primitives.option) {
Ok(SymbolValue::OptionNone) Ok(SymbolValue::OptionNone)
} else { } else {
Err(format!("Expected {:?}, but got Option", expected_ty)) Err(format!("Expected {expected_ty:?}, but got Option"))
} }
} }
Constant::Bool(b) => { Constant::Bool(b) => {
if unifier.unioned(expected_ty, primitives.bool) { if unifier.unioned(expected_ty, primitives.bool) {
Ok(SymbolValue::Bool(*b)) Ok(SymbolValue::Bool(*b))
} else { } else {
Err(format!("Expected {:?}, but got bool", expected_ty)) Err(format!("Expected {expected_ty:?}, but got bool"))
} }
} }
Constant::Str(s) => { Constant::Str(s) => {
if unifier.unioned(expected_ty, primitives.str) { if unifier.unioned(expected_ty, primitives.str) {
Ok(SymbolValue::Str(s.to_string())) Ok(SymbolValue::Str(s.to_string()))
} else { } else {
Err(format!("Expected {:?}, but got str", expected_ty)) Err(format!("Expected {expected_ty:?}, but got str"))
} }
}, },
Constant::Int(i) => { Constant::Int(i) => {
@ -107,14 +107,14 @@ impl SymbolValue {
if unifier.unioned(expected_ty, primitives.float) { if unifier.unioned(expected_ty, primitives.float) {
Ok(SymbolValue::Double(*f)) Ok(SymbolValue::Double(*f))
} else { } else {
Err(format!("Expected {:?}, but got float", expected_ty)) Err(format!("Expected {expected_ty:?}, but got float"))
} }
}, },
_ => Err(format!("Unsupported value type {:?}", constant)), _ => Err(format!("Unsupported value type {constant:?}")),
} }
} }
/// Returns the [Type] representing the data type of this value. /// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self { match self {
SymbolValue::I32(_) => primitives.int32, SymbolValue::I32(_) => primitives.int32,
@ -133,12 +133,11 @@ impl SymbolValue {
ty: vs_tys, ty: vs_tys,
}) })
} }
SymbolValue::OptionSome(_) => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
SymbolValue::OptionNone => primitives.option,
} }
} }
/// Returns the [TypeAnnotation] representing the data type of this value. /// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
match self { match self {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
@ -157,7 +156,7 @@ impl SymbolValue {
} }
SymbolValue::OptionNone => TypeAnnotation::CustomClass { SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier), id: primitives.option.get_obj_id(unifier),
params: Default::default(), params: Vec::default(),
}, },
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier); let ty = v.get_type_annotation(primitives, unifier);
@ -169,7 +168,7 @@ impl SymbolValue {
} }
} }
/// Returns the [TypeEnum] representing the data type of this value. /// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> { pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier); let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty) unifier.get_ty(ty)
@ -179,12 +178,12 @@ impl SymbolValue {
impl Display for SymbolValue { impl Display for SymbolValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
SymbolValue::I32(i) => write!(f, "{}", i), SymbolValue::I32(i) => write!(f, "{i}"),
SymbolValue::I64(i) => write!(f, "int64({})", i), SymbolValue::I64(i) => write!(f, "int64({i})"),
SymbolValue::U32(i) => write!(f, "uint32({})", i), SymbolValue::U32(i) => write!(f, "uint32({i})"),
SymbolValue::U64(i) => write!(f, "uint64({})", i), SymbolValue::U64(i) => write!(f, "uint64({i})"),
SymbolValue::Str(s) => write!(f, "\"{}\"", s), SymbolValue::Str(s) => write!(f, "\"{s}\""),
SymbolValue::Double(d) => write!(f, "{}", d), SymbolValue::Double(d) => write!(f, "{d}"),
SymbolValue::Bool(b) => { SymbolValue::Bool(b) => {
if *b { if *b {
write!(f, "True") write!(f, "True")
@ -193,9 +192,9 @@ impl Display for SymbolValue {
} }
} }
SymbolValue::Tuple(t) => { SymbolValue::Tuple(t) => {
write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join(", ")) write!(f, "({})", t.iter().map(|v| format!("{v}")).collect::<Vec<_>>().join(", "))
} }
SymbolValue::OptionSome(v) => write!(f, "Some({})", v), SymbolValue::OptionSome(v) => write!(f, "Some({v})"),
SymbolValue::OptionNone => write!(f, "none"), SymbolValue::OptionNone => write!(f, "none"),
} }
} }
@ -212,7 +211,7 @@ pub trait StaticValue {
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>; ) -> BasicValueEnum<'ctx>;
/// Converts this value to a LLVM [BasicValueEnum]. /// Converts this value to a LLVM [`BasicValueEnum`].
fn to_basic_value_enum<'ctx>( fn to_basic_value_enum<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -272,7 +271,7 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
impl<'ctx> ValueEnum<'ctx> { impl<'ctx> ValueEnum<'ctx> {
/// Converts this [ValueEnum] to a [BasicValueEnum]. /// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>( pub fn to_basic_value_enum<'a>(
self, self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -376,39 +375,36 @@ pub fn parse_type_annotation<T>(
Ok(primitives.exception) Ok(primitives.exception)
} else { } else {
let obj_id = resolver.get_identifier_def(*id); let obj_id = resolver.get_identifier_def(*id);
match obj_id { if let Ok(obj_id) = obj_id {
Ok(obj_id) => { let def = top_level_defs[obj_id.0].read();
let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if !type_vars.is_empty() {
if !type_vars.is_empty() { return Err(format!(
return Err(format!( "Unexpected number of type parameters: expected {} but got 0",
"Unexpected number of type parameters: expected {} but got 0", type_vars.len()
type_vars.len() ));
)); }
} let fields = chain(
let fields = chain( fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
fields.iter().map(|(k, v, m)| (*k, (*v, *m))), methods.iter().map(|(k, v, _)| (*k, (*v, false))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))), )
)
.collect(); .collect();
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj {
obj_id, obj_id,
fields, fields,
params: Default::default(), params: HashMap::default(),
})) }))
} else { } else {
Err(format!("Cannot use function name as type at {}", loc)) Err(format!("Cannot use function name as type at {loc}"))
}
} }
Err(_) => { } else {
let ty = resolver let ty = resolver
.get_symbol_type(unifier, top_level_defs, primitives, *id) .get_symbol_type(unifier, top_level_defs, primitives, *id)
.map_err(|e| format!("Unknown type annotation at {}: {}", loc, e))?; .map_err(|e| format!("Unknown type annotation at {loc}: {e}"))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
} else { } else {
Err(format!("Unknown type annotation {} at {}", id, loc)) Err(format!("Unknown type annotation {id} at {loc}"))
}
} }
} }
} }
@ -520,7 +516,7 @@ impl dyn SymbolResolver + Send + Sync {
unreachable!("expected class definition") unreachable!("expected class definition")
} }
}, },
&mut |id| format!("typevar{}", id), &mut |id| format!("typevar{id}"),
&mut None, &mut None,
) )
} }

View File

@ -49,7 +49,7 @@ pub fn get_exn_constructor(
FuncArg { FuncArg {
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str("".into())), default_value: Some(SymbolValue::Str(String::new())),
}, },
FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) }, FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) }, FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
@ -58,20 +58,20 @@ pub fn get_exn_constructor(
let exn_type = unifier.add_ty(TypeEnum::TObj { let exn_type = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(class_id), obj_id: DefinitionId(class_id),
fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(),
params: Default::default(), params: HashMap::default(),
}); });
let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: exn_cons_args, args: exn_cons_args,
ret: exn_type, ret: exn_type,
vars: Default::default(), vars: HashMap::default(),
})); }));
let fun_def = TopLevelDef::Function { let fun_def = TopLevelDef::Function {
name: format!("{}.__init__", name), name: format!("{name}.__init__"),
simple_name: "__init__".into(), simple_name: "__init__".into(),
signature, signature,
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))),
loc: None, loc: None,
@ -79,12 +79,12 @@ pub fn get_exn_constructor(
let class_def = TopLevelDef::Class { let class_def = TopLevelDef::Class {
name: name.into(), name: name.into(),
object_id: DefinitionId(class_id), object_id: DefinitionId(class_id),
type_vars: Default::default(), type_vars: Vec::default(),
fields: exception_fields, fields: exception_fields,
methods: vec![("__init__".into(), signature, DefinitionId(cons_id))], methods: vec![("__init__".into(), signature, DefinitionId(cons_id))],
ancestors: vec![ ancestors: vec![
TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Default::default() }, TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() },
TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }, TypeAnnotation::CustomClass { id: DefinitionId(7), params: Vec::default() },
], ],
constructor: Some(signature), constructor: Some(signature),
resolver: None, resolver: None,
@ -93,7 +93,7 @@ pub fn get_exn_constructor(
(fun_def, class_def, signature, exn_type) (fun_def, class_def, signature, exn_type)
} }
/// Creates a NumPy [TopLevelDef] function by code generation. /// Creates a NumPy [`TopLevelDef`] function by code generation.
/// ///
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
/// * `ret_ty`: The return type of this function. /// * `ret_ty`: The return type of this function.
@ -120,16 +120,16 @@ fn create_fn_by_codegen(
ret: ret_ty, ret: ret_ty,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))), codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))),
loc: None, loc: None,
})) }))
} }
/// Creates a NumPy [TopLevelDef] function using an LLVM intrinsic. /// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic.
/// ///
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
/// * `ret_ty`: The return type of this function. /// * `ret_ty`: The return type of this function.
@ -191,7 +191,7 @@ fn create_fn_by_intrinsic(
) )
} }
/// Creates a unary NumPy [TopLevelDef] function using an extern function (e.g. from `libc` or /// Creates a unary NumPy [`TopLevelDef`] function using an extern function (e.g. from `libc` or
/// `libm`). /// `libm`).
/// ///
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
@ -363,9 +363,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Arc::new(RwLock::new(TopLevelDef::Class { Arc::new(RwLock::new(TopLevelDef::Class {
name: "Exception".into(), name: "Exception".into(),
object_id: DefinitionId(7), object_id: DefinitionId(7),
type_vars: Default::default(), type_vars: Vec::default(),
fields: exception_fields, fields: exception_fields,
methods: Default::default(), methods: Vec::default(),
ancestors: vec![], ancestors: vec![],
constructor: None, constructor: None,
resolver: None, resolver: None,
@ -398,7 +398,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
], ],
ancestors: vec![TypeAnnotation::CustomClass { ancestors: vec![TypeAnnotation::CustomClass {
id: DefinitionId(10), id: DefinitionId(10),
params: Default::default(), params: Vec::default(),
}], }],
constructor: None, constructor: None,
resolver: None, resolver: None,
@ -410,8 +410,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
simple_name: "is_some".into(), simple_name: "is_some".into(),
signature: is_some_ty.0, signature: is_some_ty.0,
var_id: vec![option_ty_var_id], var_id: vec![option_ty_var_id],
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, _, _, generator| { |ctx, obj, _, _, generator| {
@ -435,8 +435,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
simple_name: "is_none".into(), simple_name: "is_none".into(),
signature: is_some_ty.0, signature: is_some_ty.0,
var_id: vec![option_ty_var_id], var_id: vec![option_ty_var_id],
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, _, _, generator| { |ctx, obj, _, _, generator| {
@ -460,8 +460,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
simple_name: "unwrap".into(), simple_name: "unwrap".into(),
signature: unwrap_ty.0, signature: unwrap_ty.0,
var_id: vec![option_ty_var_id], var_id: vec![option_ty_var_id],
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| { |_, _, _, _, _| {
@ -478,9 +478,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: int32, ret: int32,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -553,9 +553,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: int64, ret: int64,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -624,9 +624,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: uint32, ret: uint32,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -701,9 +701,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: uint64, ret: uint64,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -777,9 +777,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: float, ret: float,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -927,11 +927,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
}, },
], ],
ret: range, ret: range,
vars: Default::default(), vars: HashMap::default(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| { |ctx, _, _, args, generator| {
@ -1013,11 +1013,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "s".into(), ty: string, default_value: None }], args: vec![FuncArg { name: "s".into(), ty: string, default_value: None }],
ret: string, ret: string,
vars: Default::default(), vars: HashMap::default(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1035,9 +1035,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: primitives.0.bool, ret: primitives.0.bool,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1283,8 +1283,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.collect(), .collect(),
})), })),
var_id: vec![arg_ty.1], var_id: vec![arg_ty.1],
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1305,10 +1305,10 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
None, None,
) )
.into_int_value(); .into_int_value();
if len.get_type().get_bit_width() != 32 { if len.get_type().get_bit_width() == 32 {
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
} else {
Some(len.into()) Some(len.into())
} else {
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
} }
}) })
}, },
@ -1327,9 +1327,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: num_ty.0, ret: num_ty.0,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1389,9 +1389,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: num_ty.0, ret: num_ty.0,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1448,9 +1448,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ret: num_ty.0, ret: num_ty.0,
vars: var_map.clone(), vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {
@ -1888,8 +1888,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
vars: HashMap::from([(option_ty_var_id, option_ty_var)]), vars: HashMap::from([(option_ty_var_id, option_ty_var)]),
})), })),
var_id: vec![option_ty_var_id], var_id: vec![option_ty_var_id],
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| { |ctx, _, fun, args, generator| {

View File

@ -41,13 +41,14 @@ pub struct TopLevelComposer {
impl Default for TopLevelComposer { impl Default for TopLevelComposer {
fn default() -> Self { fn default() -> Self {
Self::new(vec![], Default::default()).0 Self::new(vec![], ComposerConfig::default()).0
} }
} }
impl TopLevelComposer { impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
#[must_use]
pub fn new( pub fn new(
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>, builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
core_config: ComposerConfig, core_config: ComposerConfig,
@ -77,11 +78,11 @@ impl TopLevelComposer {
"Some".into(), "Some".into(),
"Option".into(), "Option".into(),
]); ]);
let defined_names: HashSet<String> = Default::default(); let defined_names = HashSet::default();
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default(); let method_class = HashMap::default();
let mut builtin_id: HashMap<StrRef, DefinitionId> = Default::default(); let mut builtin_id = HashMap::default();
let mut builtin_ty: HashMap<StrRef, Type> = Default::default(); let mut builtin_ty = HashMap::default();
let builtin_name_list = definition_ast_list.iter() let builtin_name_list = definition_ast_list.iter()
.map(|def_ast| match *def_ast.0.read() { .map(|def_ast| match *def_ast.0.read() {
@ -123,9 +124,9 @@ impl TopLevelComposer {
name: name.into(), name: name.into(),
simple_name: name, simple_name: name,
signature: fun_sig, signature: fun_sig,
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
var_id: Default::default(), var_id: Vec::default(),
resolver: None, resolver: None,
codegen_callback: Some(codegen_callback), codegen_callback: Some(codegen_callback),
loc: None, loc: None,
@ -151,6 +152,7 @@ impl TopLevelComposer {
) )
} }
#[must_use]
pub fn make_top_level_context(&self) -> TopLevelContext { pub fn make_top_level_context(&self) -> TopLevelContext {
TopLevelContext { TopLevelContext {
definitions: RwLock::new( definitions: RwLock::new(
@ -166,6 +168,7 @@ impl TopLevelComposer {
} }
} }
#[must_use]
pub fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> { pub fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> {
self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec()
} }
@ -176,9 +179,19 @@ impl TopLevelComposer {
&mut self, &mut self,
ast: Stmt<()>, ast: Stmt<()>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
mod_path: String, mod_path: &str,
allow_no_constructor: bool, allow_no_constructor: bool,
) -> Result<(StrRef, DefinitionId, Option<Type>), String> { ) -> Result<(StrRef, DefinitionId, Option<Type>), String> {
type MethodInfo = (
// the simple method name without class name
StrRef,
// in this top level def, method name is prefixed with the class name
Arc<RwLock<TopLevelDef>>,
DefinitionId,
Type,
Stmt<()>,
);
let defined_names = &mut self.defined_names; let defined_names = &mut self.defined_names;
match &ast.node { match &ast.node {
ast::StmtKind::ClassDef { name: class_name, bases, body, .. } => { ast::StmtKind::ClassDef { name: class_name, bases, body, .. } => {
@ -222,15 +235,6 @@ impl TopLevelComposer {
// parse class def body and register class methods into the def list. // parse class def body and register class methods into the def list.
// module's symbol resolver would not know the name of the class methods, // module's symbol resolver would not know the name of the class methods,
// thus cannot return their definition_id // thus cannot return their definition_id
type MethodInfo = (
// the simple method name without class name
StrRef,
// in this top level def, method name is prefixed with the class name
Arc<RwLock<TopLevelDef>>,
DefinitionId,
Type,
Stmt<()>,
);
let mut class_method_name_def_ids: Vec<MethodInfo> = Vec::new(); let mut class_method_name_def_ids: Vec<MethodInfo> = Vec::new();
// we do not push anything to the def list, so we keep track of the index // we do not push anything to the def list, so we keep track of the index
// and then push in the correct order after the for loop // and then push in the correct order after the for loop
@ -288,9 +292,6 @@ impl TopLevelComposer {
dummy_method_type, dummy_method_type,
b.clone(), b.clone(),
)); ));
} else {
// do nothing
continue;
} }
} }
@ -299,7 +300,7 @@ impl TopLevelComposer {
// get the methods into the top level class_def // get the methods into the top level class_def
for (name, _, id, ty, ..) in &class_method_name_def_ids { for (name, _, id, ty, ..) in &class_method_name_def_ids {
let mut class_def = class_def_ast.0.write(); let mut class_def = class_def_ast.0.write();
if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() { if let TopLevelDef::Class { methods, .. } = &mut *class_def {
methods.push((*name, *ty, *id)); methods.push((*name, *ty, *id));
self.method_class.insert(*id, DefinitionId(class_def_id)); self.method_class.insert(*id, DefinitionId(class_def_id));
} else { } else {
@ -320,7 +321,7 @@ impl TopLevelComposer {
let global_fun_name = if mod_path.is_empty() { let global_fun_name = if mod_path.is_empty() {
name.to_string() name.to_string()
} else { } else {
format!("{}.{}", mod_path, name) format!("{mod_path}.{name}")
}; };
if !defined_names.insert(global_fun_name.clone()) { if !defined_names.insert(global_fun_name.clone()) {
return Err(format!( return Err(format!(
@ -383,7 +384,7 @@ impl TopLevelComposer {
// only deal with class def here // only deal with class def here
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = { let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() { if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
if let Some(ast::Located { if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, .. node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast }) = class_ast
@ -397,7 +398,7 @@ impl TopLevelComposer {
} }
}; };
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref(); let class_resolver = &**class_resolver;
let mut is_generic = false; let mut is_generic = false;
for b in class_bases_ast { for b in class_bases_ast {
@ -415,14 +416,13 @@ impl TopLevelComposer {
) )
} => } =>
{ {
if !is_generic { if is_generic {
is_generic = true;
} else {
return Err(format!( return Err(format!(
"only single Generic[...] is allowed (at {})", "only single Generic[...] is allowed (at {})",
b.location b.location
)); ));
} }
is_generic = true;
let type_var_list: Vec<&ast::Expr<()>>; let type_var_list: Vec<&ast::Expr<()>>;
// if `class A(Generic[T, V, G])` // if `class A(Generic[T, V, G])`
@ -430,7 +430,7 @@ impl TopLevelComposer {
type_var_list = elts.iter().collect_vec(); type_var_list = elts.iter().collect_vec();
// `class A(Generic[T])` // `class A(Generic[T])`
} else { } else {
type_var_list = vec![slice.deref()]; type_var_list = vec![&**slice];
} }
// parse the type vars // parse the type vars
@ -509,7 +509,7 @@ impl TopLevelComposer {
let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = {
if let TopLevelDef::Class { if let TopLevelDef::Class {
ancestors, resolver, object_id, type_vars, .. ancestors, resolver, object_id, type_vars, ..
} = class_def.deref_mut() } = &mut *class_def
{ {
if let Some(ast::Located { if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, node: ast::StmtKind::ClassDef { bases, .. },
@ -525,7 +525,7 @@ impl TopLevelComposer {
} }
}; };
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.deref(); let class_resolver = &**class_resolver;
let mut has_base = false; let mut has_base = false;
for b in class_bases { for b in class_bases {
@ -589,11 +589,11 @@ impl TopLevelComposer {
} }
// second, get all ancestors // second, get all ancestors
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = Default::default(); let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default();
let mut get_all_ancestors = |class_def: &Arc<RwLock<TopLevelDef>>| { let mut get_all_ancestors = |class_def: &Arc<RwLock<TopLevelDef>>| {
let class_def = class_def.read(); let class_def = class_def.read();
let (class_ancestors, class_id) = { let (class_ancestors, class_id) = {
if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() { if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
(ancestors, *object_id) (ancestors, *object_id)
} else { } else {
return Ok(()); return Ok(());
@ -630,7 +630,7 @@ impl TopLevelComposer {
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_ancestors, class_id, class_type_vars) = { let (class_ancestors, class_id, class_type_vars) = {
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } =
class_def.deref_mut() &mut *class_def
{ {
(ancestors, *object_id, type_vars) (ancestors, *object_id, type_vars)
} else { } else {
@ -652,7 +652,7 @@ impl TopLevelComposer {
{ {
// if inherited from Exception, the body should be a pass // if inherited from Exception, the body should be a pass
if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node { if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node {
for stmt in body.iter() { for stmt in body {
if matches!( if matches!(
stmt.node, stmt.node,
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
@ -695,7 +695,7 @@ impl TopLevelComposer {
} }
if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { if matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
if let Err(e) = Self::analyze_single_class_methods_fields( if let Err(e) = Self::analyze_single_class_methods_fields(
class_def.clone(), class_def,
&class_ast.as_ref().unwrap().node, &class_ast.as_ref().unwrap().node,
&temp_def_list, &temp_def_list,
unifier, unifier,
@ -724,13 +724,13 @@ impl TopLevelComposer {
continue; continue;
} }
let mut class_def = class_def.write(); let mut class_def = class_def.write();
if let TopLevelDef::Class { ancestors, .. } = class_def.deref() { if let TopLevelDef::Class { ancestors, .. } = &*class_def {
// if the length of the ancestor is equal to the current depth // if the length of the ancestor is equal to the current depth
// it means that all the ancestors of the class is handled // it means that all the ancestors of the class is handled
if ancestors.len() == current_ancestor_depth { if ancestors.len() == current_ancestor_depth {
finished = false; finished = false;
Self::analyze_single_class_ancestors( Self::analyze_single_class_ancestors(
class_def.deref_mut(), &mut class_def,
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives, primitives,
@ -742,10 +742,9 @@ impl TopLevelComposer {
if finished { if finished {
break; break;
} else {
current_ancestor_depth += 1;
} }
current_ancestor_depth += 1;
if current_ancestor_depth > def_ast_list.len() + 1 { if current_ancestor_depth > def_ast_list.len() + 1 {
unreachable!("cannot be longer than the whole top level def list") unreachable!("cannot be longer than the whole top level def list")
} }
@ -764,11 +763,11 @@ impl TopLevelComposer {
errors.insert(e); errors.insert(e);
} }
} }
for ty in subst_list.unwrap().into_iter() { for ty in subst_list.unwrap() {
if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) { if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) {
let mut new_fields = HashMap::new(); let mut new_fields = HashMap::new();
let mut need_subst = false; let mut need_subst = false;
for (name, (ty, mutable)) in fields.iter() { for (name, (ty, mutable)) in fields {
let substituted = unifier.subst(*ty, params); let substituted = unifier.subst(*ty, params);
need_subst |= substituted.is_some(); need_subst |= substituted.is_some();
new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable));
@ -817,10 +816,8 @@ impl TopLevelComposer {
let mut errors = HashSet::new(); let mut errors = HashSet::new();
let mut analyze = |function_def: &Arc<RwLock<TopLevelDef>>, function_ast: &Option<Stmt>| { let mut analyze = |function_def: &Arc<RwLock<TopLevelDef>>, function_ast: &Option<Stmt>| {
let mut function_def = function_def.write(); let mut function_def = function_def.write();
let function_def = function_def.deref_mut(); let function_def = &mut *function_def;
let function_ast = if let Some(x) = function_ast.as_ref() { let Some(function_ast) = function_ast.as_ref() else {
x
} else {
// if let TopLevelDef::Function { name, .. } = `` // if let TopLevelDef::Function { name, .. } = ``
return Ok(()); return Ok(());
}; };
@ -835,13 +832,13 @@ impl TopLevelComposer {
if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node {
let resolver = resolver.as_ref(); let resolver = resolver.as_ref();
let resolver = resolver.unwrap(); let resolver = resolver.unwrap();
let resolver = resolver.deref(); let resolver = &**resolver;
let mut function_var_map: HashMap<u32, Type> = HashMap::new(); let mut function_var_map: HashMap<u32, Type> = HashMap::new();
let arg_types = { let arg_types = {
// make sure no duplicate parameter // make sure no duplicate parameter
let mut defined_parameter_name: HashSet<_> = HashSet::new(); let mut defined_parameter_name: HashSet<_> = HashSet::new();
for x in args.args.iter() { for x in &args.args {
if !defined_parameter_name.insert(x.node.arg) if !defined_parameter_name.insert(x.node.arg)
|| keyword_list.contains(&x.node.arg) || keyword_list.contains(&x.node.arg)
{ {
@ -1037,7 +1034,7 @@ impl TopLevelComposer {
} }
fn analyze_single_class_methods_fields( fn analyze_single_class_methods_fields(
class_def: Arc<RwLock<TopLevelDef>>, class_def: &Arc<RwLock<TopLevelDef>>,
class_ast: &ast::StmtKind<()>, class_ast: &ast::StmtKind<()>,
temp_def_list: &[Arc<RwLock<TopLevelDef>>], temp_def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
@ -1091,7 +1088,7 @@ impl TopLevelComposer {
// check method parameters cannot have same name // check method parameters cannot have same name
let mut defined_parameter_name: HashSet<_> = HashSet::new(); let mut defined_parameter_name: HashSet<_> = HashSet::new();
let zelf: StrRef = "self".into(); let zelf: StrRef = "self".into();
for x in args.args.iter() { for x in &args.args {
if !defined_parameter_name.insert(x.node.arg) if !defined_parameter_name.insert(x.node.arg)
|| (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf)
{ {
@ -1206,7 +1203,7 @@ impl TopLevelComposer {
// into the list for later unification // into the list for later unification
type_var_to_concrete_def type_var_to_concrete_def
.insert(dummy_func_arg.ty, type_ann.clone()); .insert(dummy_func_arg.ty, type_ann.clone());
result.push(dummy_func_arg) result.push(dummy_func_arg);
} }
} }
result result
@ -1255,7 +1252,7 @@ impl TopLevelComposer {
}; };
if let TopLevelDef::Function { var_id, .. } = if let TopLevelDef::Function { var_id, .. } =
temp_def_list.get(method_id.0).unwrap().write().deref_mut() &mut *temp_def_list.get(method_id.0).unwrap().write()
{ {
var_id.extend_from_slice(method_var_map var_id.extend_from_slice(method_var_map
.iter() .iter()
@ -1410,7 +1407,7 @@ impl TopLevelComposer {
// find if there is a method with same name in the child class // find if there is a method with same name in the child class
let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id);
for (class_method_name, class_method_ty, class_method_defid) in for (class_method_name, class_method_ty, class_method_defid) in
class_methods_def.iter() &*class_methods_def
{ {
if class_method_name == anc_method_name { if class_method_name == anc_method_name {
// ignore and handle self // ignore and handle self
@ -1424,8 +1421,7 @@ impl TopLevelComposer {
); );
if !ok { if !ok {
return Err(format!( return Err(format!(
"method {} has same name as ancestors' method, but incompatible type", "method {class_method_name} has same name as ancestors' method, but incompatible type"
class_method_name
)); ));
} }
// mark it as added // mark it as added
@ -1439,7 +1435,7 @@ impl TopLevelComposer {
} }
// add those that are not overriding method to the new_child_methods // add those that are not overriding method to the new_child_methods
for (class_method_name, class_method_ty, class_method_defid) in for (class_method_name, class_method_ty, class_method_defid) in
class_methods_def.iter() &*class_methods_def
{ {
if !is_override.contains(class_method_name) { if !is_override.contains(class_method_name) {
new_child_methods.push(( new_child_methods.push((
@ -1459,17 +1455,16 @@ impl TopLevelComposer {
for (anc_field_name, anc_field_ty, mutable) in fields { for (anc_field_name, anc_field_ty, mutable) in fields {
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
// find if there is a fields with the same name in the child class // find if there is a fields with the same name in the child class
for (class_field_name, ..) in class_fields_def.iter() { for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name { if class_field_name == anc_field_name {
return Err(format!( return Err(format!(
"field `{}` has already declared in the ancestor classes", "field `{class_field_name}` has already declared in the ancestor classes"
class_field_name
)); ));
} }
} }
new_child_fields.push(to_be_added); new_child_fields.push(to_be_added);
} }
for (class_field_name, class_field_ty, mutable) in class_fields_def.iter() { for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
if !is_override.contains(class_field_name) { if !is_override.contains(class_field_name) {
new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
} }
@ -1486,7 +1481,8 @@ impl TopLevelComposer {
Ok(()) Ok(())
} }
/// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of topleveldef::function /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
/// [`TopLevelDef::Function`]
fn analyze_function_instance(&mut self) -> Result<(), String> { fn analyze_function_instance(&mut self) -> Result<(), String> {
// first get the class constructor type correct for the following type check in function body // first get the class constructor type correct for the following type check in function body
// also do class field instantiation check // also do class field instantiation check
@ -1558,7 +1554,7 @@ impl TopLevelComposer {
FuncArg { FuncArg {
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str("".into())), default_value: Some(SymbolValue::Str(String::new())),
}, },
FuncArg { FuncArg {
name: "param0".into(), name: "param0".into(),
@ -1577,15 +1573,15 @@ impl TopLevelComposer {
}, },
], ],
ret: self_type, ret: self_type,
vars: Default::default(), vars: HashMap::default(),
})); }));
let cons_fun = TopLevelDef::Function { let cons_fun = TopLevelDef::Function {
name: format!("{}.{}", class_name, "__init__"), name: format!("{}.{}", class_name, "__init__"),
simple_name: init_str_id, simple_name: init_str_id,
signature, signature,
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))),
loc: None, loc: None,
@ -1661,7 +1657,7 @@ impl TopLevelComposer {
return Err(errors.into_iter().sorted().join("\n---------\n")); return Err(errors.into_iter().sorted().join("\n---------\n"));
} }
for (i, signature, id) in constructors.into_iter() { for (i, signature, id) in constructors {
if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write()
{ {
methods.push(( methods.push((
@ -1748,13 +1744,13 @@ impl TopLevelComposer {
}) })
.multi_cartesian_product() .multi_cartesian_product()
.collect_vec(); .collect_vec();
let mut result: Vec<HashMap<u32, Type>> = Default::default(); let mut result: Vec<HashMap<u32, Type>> = Vec::default();
for comb in var_combs { for comb in var_combs {
result.push(vars.keys().cloned().zip(comb).collect()); result.push(vars.keys().copied().zip(comb).collect());
} }
// NOTE: if is empty, means no type var, append a empty subst, ok to do this? // NOTE: if is empty, means no type var, append a empty subst, ok to do this?
if result.is_empty() { if result.is_empty() {
result.push(HashMap::new()) result.push(HashMap::new());
} }
(result, no_ranges) (result, no_ranges)
}; };
@ -1844,14 +1840,14 @@ impl TopLevelComposer {
&& matches!(&decorator_list[0].node, && matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"extern".into()) ast::ExprKind::Name{ id, .. } if id == &"extern".into())
{ {
instance_to_symbol.insert("".into(), simple_name.to_string()); instance_to_symbol.insert(String::new(), simple_name.to_string());
continue; continue;
} }
if !decorator_list.is_empty() if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node, && matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) ast::ExprKind::Name{ id, .. } if id == &"rpc".into())
{ {
instance_to_symbol.insert("".into(), simple_name.to_string()); instance_to_symbol.insert(String::new(), simple_name.to_string());
continue; continue;
} }
body body
@ -1867,15 +1863,14 @@ impl TopLevelComposer {
{ {
// check virtuals // check virtuals
let defs = ctx.definitions.read(); let defs = ctx.definitions.read();
for (subtype, base, loc) in inferencer.virtual_checks.iter() { for (subtype, base, loc) in &*inferencer.virtual_checks {
let base_id = { let base_id = {
let base = inferencer.unifier.get_ty(*base); let base = inferencer.unifier.get_ty(*base);
if let TypeEnum::TObj { obj_id, .. } = &*base { if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id *obj_id
} else { } else {
return Err(format!( return Err(format!(
"Base type should be a class (at {})", "Base type should be a class (at {loc})"
loc
)); ));
} }
}; };
@ -1887,8 +1882,7 @@ impl TopLevelComposer {
let base_repr = inferencer.unifier.stringify(*base); let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(format!( return Err(format!(
"Expected a subtype of {}, but got {} (at {})", "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"
base_repr, subtype_repr, loc
)); ));
} }
}; };
@ -1900,8 +1894,7 @@ impl TopLevelComposer {
let base_repr = inferencer.unifier.stringify(*base); let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(format!( return Err(format!(
"Expected a subtype of {}, but got {} (at {})", "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"
base_repr, subtype_repr, loc
)); ));
} }
} else { } else {
@ -1922,7 +1915,7 @@ impl TopLevelComposer {
unreachable!("must be class id here") unreachable!("must be class id here")
} }
}, },
&mut |id| format!("typevar{}", id), &mut |id| format!("typevar{id}"),
&mut None, &mut None,
); );
return Err(format!( return Err(format!(
@ -1934,7 +1927,7 @@ impl TopLevelComposer {
} }
instance_to_stmt.insert( instance_to_stmt.insert(
get_subst_key(unifier, self_type, &subst, Some(&vars.keys().cloned().collect())), get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())),
FunInstance { FunInstance {
body: Arc::new(fun_body), body: Arc::new(fun_body),
unifier_id: 0, unifier_id: 0,

View File

@ -43,6 +43,7 @@ impl TopLevelDef {
} }
impl TopLevelComposer { impl TopLevelComposer {
#[must_use]
pub fn make_primitives() -> (PrimitiveStore, Unifier) { pub fn make_primitives() -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
@ -134,22 +135,23 @@ impl TopLevelComposer {
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
int64, int64,
uint32,
uint64,
float, float,
bool, bool,
none, none,
range, range,
str, str,
exception, exception,
uint32,
uint64,
option, option,
}; };
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
(primitives, unifier) (primitives, unifier)
} }
/// already include the definition_id of itself inside the ancestors vector /// already include the `definition_id` of itself inside the ancestors vector
/// when first registering, the type_vars, fields, methods, ancestors are invalid /// when first registering, the `type_vars`, fields, methods, ancestors are invalid
#[must_use]
pub fn make_top_level_class_def( pub fn make_top_level_class_def(
index: usize, index: usize,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
@ -160,10 +162,10 @@ impl TopLevelComposer {
TopLevelDef::Class { TopLevelDef::Class {
name, name,
object_id: DefinitionId(index), object_id: DefinitionId(index),
type_vars: Default::default(), type_vars: Vec::default(),
fields: Default::default(), fields: Vec::default(),
methods: Default::default(), methods: Vec::default(),
ancestors: Default::default(), ancestors: Vec::default(),
constructor, constructor,
resolver, resolver,
loc, loc,
@ -171,6 +173,7 @@ impl TopLevelComposer {
} }
/// when first registering, the type is a invalid value /// when first registering, the type is a invalid value
#[must_use]
pub fn make_top_level_function_def( pub fn make_top_level_function_def(
name: String, name: String,
simple_name: StrRef, simple_name: StrRef,
@ -182,15 +185,16 @@ impl TopLevelComposer {
name, name,
simple_name, simple_name,
signature: ty, signature: ty,
var_id: Default::default(), var_id: Vec::default(),
instance_to_symbol: Default::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: Default::default(), instance_to_stmt: HashMap::default(),
resolver, resolver,
codegen_callback: None, codegen_callback: None,
loc, loc,
} }
} }
#[must_use]
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String { pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
class_name.push('.'); class_name.push('.');
class_name.push_str(method_name); class_name.push_str(method_name);
@ -206,7 +210,7 @@ impl TopLevelComposer {
return Ok((*ty, *def_id)); return Ok((*ty, *def_id));
} }
} }
Err(format!("no method {} in the current class", method_name)) Err(format!("no method {method_name} in the current class"))
} }
/// get all base class def id of a class, excluding itself. \ /// get all base class def id of a class, excluding itself. \
@ -257,17 +261,17 @@ impl TopLevelComposer {
let child_def = temp_def_list.get(child_id.0).unwrap(); let child_def = temp_def_list.get(child_id.0).unwrap();
let child_def = child_def.read(); let child_def = child_def.read();
if let TopLevelDef::Class { ancestors, .. } = &*child_def { if let TopLevelDef::Class { ancestors, .. } = &*child_def {
if !ancestors.is_empty() { if ancestors.is_empty() {
Some(ancestors[0].clone())
} else {
None None
} else {
Some(ancestors[0].clone())
} }
} else { } else {
unreachable!("child must be top level class def") unreachable!("child must be top level class def")
} }
} }
/// get the var_id of a given TVar type /// get the `var_id` of a given `TVar` type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> { pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id) Ok(*id)
@ -376,14 +380,14 @@ impl TopLevelComposer {
ast::StmtKind::If { body, orelse, .. } => { ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned() .copied()
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
} }
ast::StmtKind::Try { body, orelse, finalbody, .. } => { ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned() .copied()
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
@ -391,9 +395,9 @@ impl TopLevelComposer {
ast::StmtKind::With { body, .. } => { ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(body.as_slice())?);
} }
ast::StmtKind::Pass { .. } => {} ast::StmtKind::Pass { .. }
ast::StmtKind::Assert { .. } => {} | ast::StmtKind::Assert { .. }
ast::StmtKind::Expr { .. } => {} | ast::StmtKind::Expr { .. } => {}
_ => { _ => {
unimplemented!() unimplemented!()
@ -448,14 +452,14 @@ impl TopLevelComposer {
} }
let found = val.get_type_annotation(primitive, unifier); let found = val.get_type_annotation(primitive, unifier);
if !is_compatible(&found, ty, unifier, primitive) { if is_compatible(&found, ty, unifier, primitive) {
Ok(())
} else {
Err(format!( Err(format!(
"incompatible default parameter type, expect {}, found {}", "incompatible default parameter type, expect {}, found {}",
ty.stringify(unifier), ty.stringify(unifier),
found.stringify(unifier), found.stringify(unifier),
)) ))
} else {
Ok(())
} }
} }
} }
@ -470,7 +474,7 @@ pub fn parse_parameter_default_value(
if let Ok(v) = (*v).try_into() { if let Ok(v) = (*v).try_into() {
Ok(SymbolValue::I32(v)) Ok(SymbolValue::I32(v))
} else { } else {
Err(format!("integer value out of range at {}", loc)) Err(format!("integer value out of range at {loc}"))
} }
} }
Constant::Float(v) => Ok(SymbolValue::Double(*v)), Constant::Float(v) => Ok(SymbolValue::Double(*v)),
@ -479,8 +483,7 @@ pub fn parse_parameter_default_value(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?, tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)), )),
Constant::None => Err(format!( Constant::None => Err(format!(
"`None` is not supported, use `none` for option type instead ({})", "`None` is not supported, use `none` for option type instead ({loc})"
loc
)), )),
_ => unimplemented!("this constant is not supported at {}", loc), _ => unimplemented!("this constant is not supported at {}", loc),
} }

View File

@ -3,7 +3,6 @@ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
fmt::Debug, fmt::Debug,
iter::FromIterator, iter::FromIterator,
ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
}; };
@ -49,6 +48,7 @@ pub struct GenCall {
} }
impl GenCall { impl GenCall {
#[must_use]
pub fn new(fp: GenCallCallback) -> GenCall { pub fn new(fp: GenCallCallback) -> GenCall {
GenCall { fp } GenCall { fp }
} }

View File

@ -33,17 +33,16 @@ impl TypeAnnotation {
match self { match self {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => { CustomClass { id, params } => {
let class_name = match unifier.top_level { let class_name = if let Some(ref top) = unifier.top_level {
Some(ref top) => { if let TopLevelDef::Class { name, .. } =
if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read()
&*top.definitions.read()[id.0].read() {
{ (*name).into()
(*name).into() } else {
} else { unreachable!()
unreachable!()
}
} }
None => format!("class_def_{}", id.0), } else {
format!("class_def_{}", id.0)
}; };
format!( format!(
"{}{}", "{}{}",
@ -51,9 +50,9 @@ impl TypeAnnotation {
{ {
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "); let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() { if param_list.is_empty() {
"".into() String::new()
} else { } else {
format!("[{}]", param_list) format!("[{param_list}]")
} }
} }
) )
@ -68,12 +67,12 @@ impl TypeAnnotation {
} }
} }
/// Parses an AST expression `expr` into a [TypeAnnotation]. /// Parses an AST expression `expr` into a [`TypeAnnotation`].
/// ///
/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all /// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition. /// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass /// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [None] when this function is invoked externally. /// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -102,7 +101,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else if id == &"str".into() { } else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str)) Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() { } else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }) Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) { } else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
@ -356,7 +355,7 @@ pub fn get_type_from_type_annotation_kinds(
match ann { match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read(); let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = def_read.deref(); let class_def: &TopLevelDef = &def_read;
let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else { let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
unreachable!("should be class def here") unreachable!("should be class def here")
}; };
@ -406,8 +405,8 @@ pub fn get_type_from_type_annotation_kinds(
"cannot apply type {} to type variable with id {:?}", "cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify( unifier.internal_stringify(
p, p,
&mut |id| format!("class{}", id), &mut |id| format!("class{id}"),
&mut |id| format!("typevar{}", id), &mut |id| format!("typevar{id}"),
&mut None &mut None
), ),
*id *id
@ -521,9 +520,10 @@ pub fn get_type_from_type_annotation_kinds(
/// considered to be type variables associated with the class \ /// considered to be type variables associated with the class \
/// \ /// \
/// But note that here we do not make a duplication of `T`, `V`, we directly /// But note that here we do not make a duplication of `T`, `V`, we directly
/// use them as they are in the TopLevelDef::Class since those in the /// use them as they are in the [`TopLevelDef::Class`] since those in the
/// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations /// `TopLevelDef::Class.type_vars` will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation /// the Type of their fields and methods will also be subst when application/instantiation
#[must_use]
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: object_id, id: object_id,
@ -534,21 +534,19 @@ pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) ->
/// get all the occurences of type vars contained in a type annotation /// get all the occurences of type vars contained in a type annotation
/// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G] /// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G]
/// this function will not make a duplicate of type var /// this function will not make a duplicate of type var
#[must_use]
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> { pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
let mut result: Vec<TypeAnnotation> = Vec::new(); let mut result: Vec<TypeAnnotation> = Vec::new();
match ann { match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()), TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) => { TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref())) result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
} }
TypeAnnotation::CustomClass { params, .. } => { TypeAnnotation::CustomClass { params, .. } => {
for p in params { for p in params {
result.extend(get_type_var_contained_in_type_annotation(p)); result.extend(get_type_var_contained_in_type_annotation(p));
} }
} }
TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()))
}
TypeAnnotation::Tuple(anns) => { TypeAnnotation::Tuple(anns) => {
for a in anns { for a in anns {
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
@ -569,9 +567,9 @@ pub fn check_overload_type_annotation_compatible(
(TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b, (TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b,
(TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => { (TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => {
let a = unifier.get_ty(*a); let a = unifier.get_ty(*a);
let a = a.deref(); let a = &*a;
let b = unifier.get_ty(*b); let b = unifier.get_ty(*b);
let b = b.deref(); let b = &*b;
if let ( if let (
TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. },

View File

@ -30,7 +30,7 @@ impl<'a> Inferencer<'a> {
Ok(()) Ok(())
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() { for elt in elts {
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
@ -98,7 +98,7 @@ impl<'a> Inferencer<'a> {
ExprKind::List { elts, .. } ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. } | ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => { | ExprKind::BoolOp { values: elts, .. } => {
for elt in elts.iter() { for elt in elts {
self.check_expr(elt, defined_identifiers)?; self.check_expr(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
@ -116,9 +116,8 @@ impl<'a> Inferencer<'a> {
// Check whether a bitwise shift has a negative RHS constant value // Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift { if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node { if let ExprKind::Constant { value, .. } = &right.node {
let rhs_val = match value { let Constant::Int(rhs_val) = value else {
Constant::Int(v) => v, unreachable!()
_ => unreachable!(),
}; };
if *rhs_val < 0 { if *rhs_val < 0 {
@ -158,7 +157,7 @@ impl<'a> Inferencer<'a> {
} }
ExprKind::Lambda { args, body } => { ExprKind::Lambda { args, body } => {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
for arg in args.args.iter() { for arg in &args.args {
// TODO: should we check the types here? // TODO: should we check the types here?
if !defined_identifiers.contains(&arg.node.arg) { if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg); defined_identifiers.insert(arg.node.arg);
@ -207,13 +206,13 @@ impl<'a> Inferencer<'a> {
self.check_expr(iter, defined_identifiers)?; self.check_expr(iter, defined_identifiers)?;
self.should_have_value(iter)?; self.should_have_value(iter)?;
let mut local_defined_identifiers = defined_identifiers.clone(); let mut local_defined_identifiers = defined_identifiers.clone();
for stmt in orelse.iter() { for stmt in orelse {
self.check_stmt(stmt, &mut local_defined_identifiers)?; self.check_stmt(stmt, &mut local_defined_identifiers)?;
} }
let mut local_defined_identifiers = defined_identifiers.clone(); let mut local_defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut local_defined_identifiers)?; self.check_pattern(target, &mut local_defined_identifiers)?;
self.should_have_value(target)?; self.should_have_value(target)?;
for stmt in body.iter() { for stmt in body {
self.check_stmt(stmt, &mut local_defined_identifiers)?; self.check_stmt(stmt, &mut local_defined_identifiers)?;
} }
Ok(false) Ok(false)
@ -226,7 +225,7 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?; let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.iter() { for ident in &body_identifiers {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.insert(*ident); defined_identifiers.insert(*ident);
} }
@ -243,7 +242,7 @@ impl<'a> Inferencer<'a> {
} }
StmtKind::With { items, body, .. } => { StmtKind::With { items, body, .. } => {
let mut new_defined_identifiers = defined_identifiers.clone(); let mut new_defined_identifiers = defined_identifiers.clone();
for item in items.iter() { for item in items {
self.check_expr(&item.context_expr, defined_identifiers)?; self.check_expr(&item.context_expr, defined_identifiers)?;
if let Some(var) = item.optional_vars.as_ref() { if let Some(var) = item.optional_vars.as_ref() {
self.check_pattern(var, &mut new_defined_identifiers)?; self.check_pattern(var, &mut new_defined_identifiers)?;
@ -255,7 +254,7 @@ impl<'a> Inferencer<'a> {
StmtKind::Try { body, handlers, orelse, finalbody, .. } => { StmtKind::Try { body, handlers, orelse, finalbody, .. } => {
self.check_block(body, &mut defined_identifiers.clone())?; self.check_block(body, &mut defined_identifiers.clone())?;
self.check_block(orelse, &mut defined_identifiers.clone())?; self.check_block(orelse, &mut defined_identifiers.clone())?;
for handler in handlers.iter() { for handler in handlers {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name { if let Some(name) = name {
@ -312,7 +311,7 @@ impl<'a> Inferencer<'a> {
let mut ret = false; let mut ret = false;
for stmt in block { for stmt in block {
if ret { if ret {
println!("warning: dead code at {:?}\n", stmt.location) println!("warning: dead code at {:?}\n", stmt.location);
} }
if self.check_stmt(stmt, defined_identifiers)? { if self.check_stmt(stmt, defined_identifiers)? {
ret = true; ret = true;

View File

@ -7,6 +7,7 @@ use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
#[must_use]
pub fn binop_name(op: &Operator) -> &'static str { pub fn binop_name(op: &Operator) -> &'static str {
match op { match op {
Operator::Add => "__add__", Operator::Add => "__add__",
@ -25,6 +26,7 @@ pub fn binop_name(op: &Operator) -> &'static str {
} }
} }
#[must_use]
pub fn binop_assign_name(op: &Operator) -> &'static str { pub fn binop_assign_name(op: &Operator) -> &'static str {
match op { match op {
Operator::Add => "__iadd__", Operator::Add => "__iadd__",
@ -43,6 +45,7 @@ pub fn binop_assign_name(op: &Operator) -> &'static str {
} }
} }
#[must_use]
pub fn unaryop_name(op: &Unaryop) -> &'static str { pub fn unaryop_name(op: &Unaryop) -> &'static str {
match op { match op {
Unaryop::UAdd => "__pos__", Unaryop::UAdd => "__pos__",
@ -52,6 +55,7 @@ pub fn unaryop_name(op: &Unaryop) -> &'static str {
} }
} }
#[must_use]
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
match op { match op {
Cmpop::Lt => Some("__lt__"), Cmpop::Lt => Some("__lt__"),
@ -183,7 +187,7 @@ pub fn impl_cmpop(
}); });
} }
/// Add, Sub, Mult /// `Add`, `Sub`, `Mult`
pub fn impl_basic_arithmetic( pub fn impl_basic_arithmetic(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
@ -198,10 +202,10 @@ pub fn impl_basic_arithmetic(
other_ty, other_ty,
ret_ty, ret_ty,
&[Operator::Add, Operator::Sub, Operator::Mult], &[Operator::Add, Operator::Sub, Operator::Mult],
) );
} }
/// Pow /// `Pow`
pub fn impl_pow( pub fn impl_pow(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
@ -209,10 +213,10 @@ pub fn impl_pow(
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Type,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
} }
/// BitOr, BitXor, BitAnd /// `BitOr`, `BitXor`, `BitAnd`
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop( impl_binop(
unifier, unifier,
@ -221,20 +225,20 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
&[ty], &[ty],
ty, ty,
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor], &[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
) );
} }
/// LShift, RShift /// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]); impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]);
} }
/// Div /// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]) impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]);
} }
/// FloorDiv /// `FloorDiv`
pub fn impl_floordiv( pub fn impl_floordiv(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
@ -242,10 +246,10 @@ pub fn impl_floordiv(
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Type,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
} }
/// Mod /// `Mod`
pub fn impl_mod( pub fn impl_mod(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
@ -253,25 +257,25 @@ pub fn impl_mod(
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Type,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
} }
/// UAdd, USub /// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]) impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]);
} }
/// Invert /// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]) impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]);
} }
/// Not /// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]) impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]);
} }
/// Lt, LtE, Gt, GtE /// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
impl_cmpop( impl_cmpop(
unifier, unifier,
@ -279,12 +283,12 @@ pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type,
ty, ty,
other_ty, other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE], &[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
) );
} }
/// Eq, NotEq /// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]) impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
} }
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {

View File

@ -43,15 +43,18 @@ pub struct TypeError {
} }
impl TypeError { impl TypeError {
#[must_use]
pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError { pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError {
TypeError { kind, loc } TypeError { kind, loc }
} }
#[must_use]
pub fn at(mut self, loc: Option<Location>) -> TypeError { pub fn at(mut self, loc: Option<Location>) -> TypeError {
self.loc = self.loc.or(loc); self.loc = self.loc.or(loc);
self self
} }
#[must_use]
pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError { pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError {
DisplayTypeError { err: self, unifier } DisplayTypeError { err: self, unifier }
} }
@ -64,8 +67,8 @@ pub struct DisplayTypeError<'a> {
fn loc_to_str(loc: Option<Location>) -> String { fn loc_to_str(loc: Option<Location>) -> String {
match loc { match loc {
Some(loc) => format!("(in {})", loc), Some(loc) => format!("(in {loc})"),
None => "".to_string(), None => String::new(),
} }
} }
@ -75,21 +78,20 @@ impl<'a> Display for DisplayTypeError<'a> {
let mut notes = Some(HashMap::new()); let mut notes = Some(HashMap::new());
match &self.err.kind { match &self.err.kind {
TooManyArguments { expected, got } => { TooManyArguments { expected, got } => {
write!(f, "Too many arguments. Expected {} but got {}", expected, got) write!(f, "Too many arguments. Expected {expected} but got {got}")
} }
MissingArgs(args) => { MissingArgs(args) => {
write!(f, "Missing arguments: {}", args) write!(f, "Missing arguments: {args}")
} }
UnknownArgName(name) => { UnknownArgName(name) => {
write!(f, "Unknown argument name: {}", name) write!(f, "Unknown argument name: {name}")
} }
IncorrectArgType { name, expected, got } => { IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!( write!(
f, f,
"Incorrect argument type for {}. Expected {}, but got {}", "Incorrect argument type for {name}. Expected {expected}, but got {got}"
name, expected, got
) )
} }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
@ -126,7 +128,7 @@ impl<'a> Display for DisplayTypeError<'a> {
); );
if let Some(loc) = loc { if let Some(loc) = loc {
result?; result?;
write!(f, " (in {})", loc)?; write!(f, " (in {loc})")?;
return Ok(()); return Ok(());
} }
result result
@ -136,12 +138,12 @@ impl<'a> Display for DisplayTypeError<'a> {
{ {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {} and {}", t1, t2) write!(f, "Tuple length mismatch: got {t1} and {t2}")
} }
_ => { _ => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Incompatible types: {} and {}", t1, t2) write!(f, "Incompatible types: {t1} and {t2}")
} }
} }
} }
@ -150,18 +152,17 @@ impl<'a> Display for DisplayTypeError<'a> {
write!(f, "Cannot assign to an element of a tuple") write!(f, "Cannot assign to an element of a tuple")
} else { } else {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "Cannot assign to field {} of {}, which is immutable", name, t) write!(f, "Cannot assign to field {name} of {t}, which is immutable")
} }
} }
NoSuchField(name, t) => { NoSuchField(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{}::{}` field/method does not exist", t, name) write!(f, "`{t}::{name}` field/method does not exist")
} }
TupleIndexOutOfBounds { index, len } => { TupleIndexOutOfBounds { index, len } => {
write!( write!(
f, f,
"Tuple index out of bounds. Got {} but tuple has only {} elements", "Tuple index out of bounds. Got {index} but tuple has only {len} elements"
index, len
) )
} }
RequiresTypeAnn => { RequiresTypeAnn => {
@ -172,13 +173,13 @@ impl<'a> Display for DisplayTypeError<'a> {
} }
}?; }?;
if let Some(loc) = self.err.loc { if let Some(loc) = self.err.loc {
write!(f, " at {}", loc)?; write!(f, " at {loc}")?;
} }
let notes = notes.unwrap(); let notes = notes.unwrap();
if !notes.is_empty() { if !notes.is_empty() {
write!(f, "\n\nNotes:")?; write!(f, "\n\nNotes:")?;
for line in notes.values() { for line in notes.values() {
write!(f, "\n {}", line)?; write!(f, "\n {line}")?;
} }
} }
Ok(()) Ok(())

View File

@ -65,20 +65,20 @@ struct NaiveFolder();
impl Fold<()> for NaiveFolder { impl Fold<()> for NaiveFolder {
type TargetU = Option<Type>; type TargetU = Option<Type>;
type Error = String; type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> { fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
Ok(None) Ok(None)
} }
} }
fn report_error<T>(msg: &str, location: Location) -> Result<T, String> { fn report_error<T>(msg: &str, location: Location) -> Result<T, String> {
Err(format!("{} at {}", msg, location)) Err(format!("{msg} at {location}"))
} }
impl<'a> Fold<()> for Inferencer<'a> { impl<'a> Fold<()> for Inferencer<'a> {
type TargetU = Option<Type>; type TargetU = Option<Type>;
type Error = String; type Error = String;
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> { fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
Ok(None) Ok(None)
} }
@ -138,7 +138,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
{ {
let top_level_defs = self.top_level.definitions.read(); let top_level_defs = self.top_level.definitions.read();
let mut naive_folder = NaiveFolder(); let mut naive_folder = NaiveFolder();
for handler in handlers.into_iter() { for handler in handlers {
let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } =
handler.node; handler.node;
let type_ = if let Some(type_) = type_ { let type_ = if let Some(type_) = type_ {
@ -226,65 +226,65 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
} }
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
for target in targets.iter_mut() { for target in &mut *targets {
if let ExprKind::Attribute { ctx, .. } = &mut target.node { if let ExprKind::Attribute { ctx, .. } = &mut target.node {
*ctx = ExprContext::Store; *ctx = ExprContext::Store;
} }
} }
if targets.iter().all(|t| matches!(t.node, ExprKind::Name { .. })) { if targets.iter().all(|t| matches!(t.node, ExprKind::Name { .. })) {
if let ast::StmtKind::Assign { targets, value, .. } = node.node { let ast::StmtKind::Assign { targets, value, .. } = node.node else {
let value = self.fold_expr(*value)?;
let value_ty = value.custom.unwrap();
let targets: Result<Vec<_>, _> = targets
.into_iter()
.map(|target| {
if let ExprKind::Name { id, ctx } = target.node {
self.defined_identifiers.insert(id);
let target_ty = if let Some(ty) = self.variable_mapping.get(&id)
{
*ty
} else {
let unifier = &mut self.unifier;
self.function_data
.resolver
.get_symbol_type(
unifier,
&self.top_level.definitions.read(),
self.primitives,
id,
)
.unwrap_or_else(|_| {
self.variable_mapping.insert(id, value_ty);
value_ty
})
};
let location = target.location;
self.unifier.unify(value_ty, target_ty).map(|_| Located {
location,
node: ExprKind::Name { id, ctx },
custom: Some(target_ty),
})
} else {
unreachable!()
}
})
.collect();
let loc = node.location;
let targets = targets
.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?;
return Ok(Located {
location: node.location,
node: ast::StmtKind::Assign {
targets,
value: Box::new(value),
type_comment: None,
config_comment: config_comment.clone(),
},
custom: None,
});
} else {
unreachable!() unreachable!()
} };
let value = self.fold_expr(*value)?;
let value_ty = value.custom.unwrap();
let targets: Result<Vec<_>, _> = targets
.into_iter()
.map(|target| {
if let ExprKind::Name { id, ctx } = target.node {
self.defined_identifiers.insert(id);
let target_ty = if let Some(ty) = self.variable_mapping.get(&id)
{
*ty
} else {
let unifier: &mut Unifier = self.unifier;
self.function_data
.resolver
.get_symbol_type(
unifier,
&self.top_level.definitions.read(),
self.primitives,
id,
)
.unwrap_or_else(|_| {
self.variable_mapping.insert(id, value_ty);
value_ty
})
};
let location = target.location;
self.unifier.unify(value_ty, target_ty).map(|()| Located {
location,
node: ExprKind::Name { id, ctx },
custom: Some(target_ty),
})
} else {
unreachable!()
}
})
.collect();
let loc = node.location;
let targets = targets
.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?;
return Ok(Located {
location: node.location,
node: ast::StmtKind::Assign {
targets,
value: Box::new(value),
type_comment: None,
config_comment: config_comment.clone(),
},
custom: None,
});
} }
for target in targets { for target in targets {
self.infer_pattern(target)?; self.infer_pattern(target)?;
@ -292,7 +292,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
fold::fold_stmt(self, node)? fold::fold_stmt(self, node)?
} }
ast::StmtKind::With { ref items, .. } => { ast::StmtKind::With { ref items, .. } => {
for item in items.iter() { for item in items {
if let Some(var) = &item.optional_vars { if let Some(var) = &item.optional_vars {
self.infer_pattern(var)?; self.infer_pattern(var)?;
} }
@ -302,20 +302,21 @@ impl<'a> Fold<()> for Inferencer<'a> {
_ => fold::fold_stmt(self, node)?, _ => fold::fold_stmt(self, node)?,
}; };
match &stmt.node { match &stmt.node {
ast::StmtKind::For { .. } => {} ast::StmtKind::AnnAssign { .. }
ast::StmtKind::Try { .. } => {} | ast::StmtKind::Break { .. }
| ast::StmtKind::Continue { .. }
| ast::StmtKind::Expr { .. }
| ast::StmtKind::For { .. }
| ast::StmtKind::Pass { .. }
| ast::StmtKind::Try { .. } => {}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
} }
ast::StmtKind::Assign { targets, value, .. } => { ast::StmtKind::Assign { targets, value, .. } => {
for target in targets.iter() { for target in targets {
self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?; self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?;
} }
} }
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break { .. }
| ast::StmtKind::Continue { .. }
| ast::StmtKind::Pass { .. } => {}
ast::StmtKind::Raise { exc, cause, .. } => { ast::StmtKind::Raise { exc, cause, .. } => {
if let Some(cause) = cause { if let Some(cause) = cause {
return report_error("raise ... from cause is not supported", cause.location); return report_error("raise ... from cause is not supported", cause.location);
@ -334,13 +335,13 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
} }
ast::StmtKind::With { items, .. } => { ast::StmtKind::With { items, .. } => {
for item in items.iter() { for item in items {
let ty = item.context_expr.custom.unwrap(); let ty = item.context_expr.custom.unwrap();
// if we can simply unify without creating new types... // if we can simply unify without creating new types...
let mut fast_path = false; let mut fast_path = false;
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
fast_path = true; fast_path = true;
if let Some(enter) = fields.get(&"__enter__".into()).cloned() { if let Some(enter) = fields.get(&"__enter__".into()).copied() {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) { if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) {
if !signature.args.is_empty() { if !signature.args.is_empty() {
return report_error( return report_error(
@ -368,7 +369,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
stmt.location, stmt.location,
); );
} }
if let Some(exit) = fields.get(&"__exit__".into()).cloned() { if let Some(exit) = fields.get(&"__exit__".into()).copied() {
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) { if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) {
if !signature.args.is_empty() { if !signature.args.is_empty() {
return report_error( return report_error(
@ -393,13 +394,13 @@ impl<'a> Fold<()> for Inferencer<'a> {
|| self.unifier.get_dummy_var().0, || self.unifier.get_dummy_var().0,
|var| var.custom.unwrap(), |var| var.custom.unwrap(),
), ),
vars: Default::default(), vars: HashMap::default(),
}); });
let enter = self.unifier.add_ty(enter); let enter = self.unifier.add_ty(enter);
let exit = TypeEnum::TFunc(FunSignature { let exit = TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: self.unifier.get_dummy_var().0, ret: self.unifier.get_dummy_var().0,
vars: Default::default(), vars: HashMap::default(),
}); });
let exit = self.unifier.add_ty(exit); let exit = self.unifier.add_ty(exit);
let mut fields = HashMap::new(); let mut fields = HashMap::new();
@ -489,7 +490,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
Err(e) => { Err(e) => {
return report_error( return report_error(
&format!("type error at identifier `{}` ({})", id, e), &format!("type error at identifier `{id}` ({e})"),
expr.location, expr.location,
); );
} }
@ -551,7 +552,7 @@ impl<'a> Inferencer<'a> {
Ok(()) Ok(())
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() { for elt in elts {
self.infer_pattern(elt)?; self.infer_pattern(elt)?;
} }
Ok(()) Ok(())
@ -637,7 +638,7 @@ impl<'a> Inferencer<'a> {
} }
let mut defined_identifiers = self.defined_identifiers.clone(); let mut defined_identifiers = self.defined_identifiers.clone();
for arg in args.args.iter() { for arg in &args.args {
let name = &arg.node.arg; let name = &arg.node.arg;
if !defined_identifiers.contains(name) { if !defined_identifiers.contains(name) {
defined_identifiers.insert(*name); defined_identifiers.insert(*name);
@ -649,7 +650,7 @@ impl<'a> Inferencer<'a> {
.map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)) .map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0))
.collect(); .collect();
let mut variable_mapping = self.variable_mapping.clone(); let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned()); variable_mapping.extend(fn_args.iter().copied());
let ret = self.unifier.get_dummy_var().0; let ret = self.unifier.get_dummy_var().0;
let mut new_context = Inferencer { let mut new_context = Inferencer {
@ -670,7 +671,7 @@ impl<'a> Inferencer<'a> {
.map(|(k, ty)| FuncArg { name: *k, ty: *ty, default_value: None }) .map(|(k, ty)| FuncArg { name: *k, ty: *ty, default_value: None })
.collect(), .collect(),
ret, ret,
vars: Default::default(), vars: HashMap::default(),
}; };
let body = new_context.fold_expr(body)?; let body = new_context.fold_expr(body)?;
new_context.unify(fun.ret, body.custom.unwrap(), &location)?; new_context.unify(fun.ret, body.custom.unwrap(), &location)?;
@ -739,7 +740,7 @@ impl<'a> Inferencer<'a> {
// iter should be a list of targets... // iter should be a list of targets...
// actually it should be an iterator of targets, but we don't have iter type for now // actually it should be an iterator of targets, but we don't have iter type for now
// if conditions should be bool // if conditions should be bool
for v in ifs.iter() { for v in &ifs {
new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?; new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?;
} }
@ -926,12 +927,12 @@ impl<'a> Inferencer<'a> {
} }
fn infer_identifier(&mut self, id: StrRef) -> InferenceResult { fn infer_identifier(&mut self, id: StrRef) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(&id) { Ok(if let Some(ty) = self.variable_mapping.get(&id) {
Ok(*ty) *ty
} else { } else {
let variable_mapping = &mut self.variable_mapping; let variable_mapping = &mut self.variable_mapping;
let unifier = &mut self.unifier; let unifier: &mut Unifier = self.unifier;
Ok(self self
.function_data .function_data
.resolver .resolver
.get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id)
@ -939,8 +940,8 @@ impl<'a> Inferencer<'a> {
let ty = unifier.get_dummy_var().0; let ty = unifier.get_dummy_var().0;
variable_mapping.insert(id, ty); variable_mapping.insert(id, ty);
ty ty
})) })
} })
} }
fn infer_constant(&mut self, constant: &ast::Constant, loc: &Location) -> InferenceResult { fn infer_constant(&mut self, constant: &ast::Constant, loc: &Location) -> InferenceResult {
@ -971,7 +972,7 @@ impl<'a> Inferencer<'a> {
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult { fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = self.unifier.get_dummy_var().0; let ty = self.unifier.get_dummy_var().0;
for t in elts.iter() { for t in elts {
self.unify(ty, t.custom.unwrap(), &t.location)?; self.unify(ty, t.custom.unwrap(), &t.location)?;
} }
Ok(self.unifier.add_ty(TypeEnum::TList { ty })) Ok(self.unifier.add_ty(TypeEnum::TList { ty }))
@ -992,14 +993,13 @@ impl<'a> Inferencer<'a> {
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path // just a fast path
match (fields.get(&attr), ctx == &ExprContext::Store) { match (fields.get(&attr), ctx == &ExprContext::Store) {
(Some((ty, true)), _) => Ok(*ty), (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => { (Some((_, false)), true) => {
report_error(&format!("Field `{}` is immutable", attr), value.location) report_error(&format!("Field `{attr}` is immutable"), value.location)
} }
(None, _) => { (None, _) => {
let t = self.unifier.stringify(ty); let t = self.unifier.stringify(ty);
report_error(&format!("`{}::{}` field/method does not exist", t, attr), value.location) report_error(&format!("`{t}::{attr}` field/method does not exist"), value.location)
}, },
} }
} else { } else {

View File

@ -97,8 +97,8 @@ impl From<i32> for RecordKey {
impl Display for RecordKey { impl Display for RecordKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
RecordKey::Str(s) => write!(f, "{}", s), RecordKey::Str(s) => write!(f, "{s}"),
RecordKey::Int(i) => write!(f, "{}", i), RecordKey::Int(i) => write!(f, "{i}"),
} }
} }
} }
@ -111,6 +111,7 @@ pub struct RecordField {
} }
impl RecordField { impl RecordField {
#[must_use]
pub fn new(ty: Type, mutable: bool, loc: Option<Location>) -> RecordField { pub fn new(ty: Type, mutable: bool, loc: Option<Location>) -> RecordField {
RecordField { ty, mutable, loc } RecordField { ty, mutable, loc }
} }
@ -185,6 +186,7 @@ pub enum TypeEnum {
} }
impl TypeEnum { impl TypeEnum {
#[must_use]
pub fn get_type_name(&self) -> &'static str { pub fn get_type_name(&self) -> &'static str {
match self { match self {
TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TRigidVar { .. } => "TRigidVar",
@ -220,6 +222,7 @@ impl Default for Unifier {
impl Unifier { impl Unifier {
/// Get an empty unifier /// Get an empty unifier
#[must_use]
pub fn new() -> Unifier { pub fn new() -> Unifier {
Unifier { Unifier {
unification_table: UnificationTable::new(), unification_table: UnificationTable::new(),
@ -252,6 +255,7 @@ impl Unifier {
} }
} }
#[must_use]
pub fn get_shared_unifier(&self) -> SharedUnifier { pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new(( Arc::new(Mutex::new((
self.unification_table.get_send(), self.unification_table.get_send(),
@ -261,7 +265,7 @@ impl Unifier {
} }
/// Register a type to the unifier. /// Register a type to the unifier.
/// Returns a key in the unification_table. /// Returns a key in the `unification_table`.
pub fn add_ty(&mut self, a: TypeEnum) -> Type { pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table.new_key(Rc::new(a)) self.unification_table.new_key(Rc::new(a))
} }
@ -294,6 +298,7 @@ impl Unifier {
} }
} }
#[must_use]
pub fn get_call_signature_immutable(&self, id: CallId) -> Option<FunSignature> { pub fn get_call_signature_immutable(&self, id: CallId) -> Option<FunSignature> {
let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap(); let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap();
if let TypeEnum::TFunc(sign) = &*self.get_ty_immutable(fun) { if let TypeEnum::TFunc(sign) = &*self.get_ty_immutable(fun) {
@ -307,11 +312,12 @@ impl Unifier {
self.unification_table.get_representative(ty) self.unification_table.get_representative(ty)
} }
/// Get the TypeEnum of a type. /// Get the `TypeEnum` of a type.
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> { pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
self.unification_table.probe_value(a).clone() self.unification_table.probe_value(a).clone()
} }
#[must_use]
pub fn get_ty_immutable(&self, a: Type) -> Rc<TypeEnum> { pub fn get_ty_immutable(&self, a: Type) -> Rc<TypeEnum> {
self.unification_table.probe_value_immutable(a).clone() self.unification_table.probe_value_immutable(a).clone()
} }
@ -435,7 +441,7 @@ impl Unifier {
.map(|params| { .map(|params| {
self.subst( self.subst(
ty, ty,
&zip(keys.iter().cloned(), params.iter().cloned()).collect(), &zip(keys.iter().copied(), params.iter().copied()).collect(),
) )
.unwrap_or(ty) .unwrap_or(ty)
}) })
@ -453,7 +459,7 @@ impl Unifier {
TRigidVar { .. } | TConstant { .. } => true, TRigidVar { .. } | TConstant { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
@ -461,7 +467,6 @@ impl Unifier {
// functions are instantiated for each call sites, so the function type can contain // functions are instantiated for each call sites, so the function type can contain
// type variables. // type variables.
TFunc { .. } => true, TFunc { .. } => true,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
} }
} }
@ -522,7 +527,7 @@ impl Unifier {
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?; })?;
} }
for (k, t) in kwargs.iter() { for (k, t) in kwargs {
if let Some(i) = required.iter().position(|v| v == k) { if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i); required.remove(i);
} }
@ -609,7 +614,7 @@ impl Unifier {
} }
(Some(fields1), Some(fields2)) => { (Some(fields1), Some(fields2)) => {
let mut new_fields: Mapping<_, _> = fields2.clone(); let mut new_fields: Mapping<_, _> = fields2.clone();
for (key, val1) in fields1.iter() { for (key, val1) in fields1 {
if let Some(val2) = fields2.get(key) { if let Some(val2) = fields2.get(key) {
self.unify_impl(val1.ty, val2.ty, false).map_err(|_| { self.unify_impl(val1.ty, val2.ty, false).map_err(|_| {
TypeError::new( TypeError::new(
@ -638,7 +643,7 @@ impl Unifier {
}; };
let intersection = self let intersection = self
.get_intersection(a, b) .get_intersection(a, b)
.map_err(|_| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))? .map_err(|()| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?
.unwrap(); .unwrap();
let range = if let TVar { range, .. } = &*self.get_ty(intersection) { let range = if let TVar { range, .. } = &*self.get_ty(intersection) {
range.clone() range.clone()
@ -677,7 +682,7 @@ impl Unifier {
} }
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = ty.len() as i32; let len = ty.len() as i32;
for (k, v) in fields.iter() { for (k, v) in fields {
match *k { match *k {
RecordKey::Int(i) => { RecordKey::Int(i) => {
if v.mutable { if v.mutable {
@ -706,10 +711,10 @@ impl Unifier {
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
for (k, v) in fields.iter() { for (k, v) in fields {
match *k { match *k {
RecordKey::Int(_) => { RecordKey::Int(_) => {
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))? self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?;
} }
RecordKey::Str(_) => { RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
@ -767,7 +772,7 @@ impl Unifier {
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
for (k, field) in map.iter() { for (k, field) in map {
match *k { match *k {
RecordKey::Str(s) => { RecordKey::Str(s) => {
let (ty, mutable) = fields.get(&s).copied().ok_or_else(|| { let (ty, mutable) = fields.get(&s).copied().ok_or_else(|| {
@ -799,7 +804,7 @@ impl Unifier {
(TVar { fields: Some(map), range, .. }, TVirtual { ty }) => { (TVar { fields: Some(map), range, .. }, TVirtual { ty }) => {
let ty = self.get_ty(*ty); let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() { if let TObj { fields, .. } = ty.as_ref() {
for (k, field) in map.iter() { for (k, field) in map {
match *k { match *k {
RecordKey::Str(s) => { RecordKey::Str(s) => {
let (ty, _) = fields.get(&s).copied().ok_or_else(|| { let (ty, _) = fields.get(&s).copied().ok_or_else(|| {
@ -866,7 +871,7 @@ impl Unifier {
(TCall(calls1), TCall(calls2)) => { (TCall(calls1), TCall(calls2)) => {
// we do not unify individual calls, instead we defer until the unification wtih a // we do not unify individual calls, instead we defer until the unification wtih a
// function definition. // function definition.
let calls = calls1.iter().chain(calls2.iter()).cloned().collect(); let calls = calls1.iter().chain(calls2.iter()).copied().collect();
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
self.unification_table.set_value(b, Rc::new(TCall(calls))); self.unification_table.set_value(b, Rc::new(TCall(calls)));
} }
@ -879,7 +884,7 @@ impl Unifier {
.rev() .rev()
.collect(); .collect();
// we unify every calls to the function signature. // we unify every calls to the function signature.
for c in calls.iter() { for c in calls {
let call = self.calls[c.0].clone(); let call = self.calls[c.0].clone();
self.unify_call(&call, b, signature, &required)?; self.unify_call(&call, b, signature, &required)?;
} }
@ -912,9 +917,9 @@ impl Unifier {
_ => { _ => {
if swapped { if swapped {
return self.incompatible_types(a, b); return self.incompatible_types(a, b);
} else {
self.unify_impl(b, a, true)?;
} }
self.unify_impl(b, a, true)?;
} }
} }
Ok(()) Ok(())
@ -934,7 +939,7 @@ impl Unifier {
ty, ty,
&mut |id| { &mut |id| {
top_level.as_ref().map_or_else( top_level.as_ref().map_or_else(
|| format!("{}", id), || format!("{id}"),
|top_level| { |top_level| {
if let TopLevelDef::Class { name, .. } = if let TopLevelDef::Class { name, .. } =
&*top_level.definitions.read()[id].read() &*top_level.definitions.read()[id].read()
@ -946,7 +951,7 @@ impl Unifier {
}, },
) )
}, },
&mut |id| format!("typevar{}", id), &mut |id| format!("typevar{id}"),
notes, notes,
) )
} }
@ -989,7 +994,7 @@ impl Unifier {
if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id) if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id)
{ {
// just in case if there is any cyclic dependency // just in case if there is any cyclic dependency
notes.as_mut().unwrap().insert(*id, "".into()); notes.as_mut().unwrap().insert(*id, String::new());
let body = format!( let body = format!(
"{} ∈ {{{}}}", "{} ∈ {{{}}}",
n, n,
@ -1022,15 +1027,15 @@ impl Unifier {
} }
TypeEnum::TObj { obj_id, params, .. } => { TypeEnum::TObj { obj_id, params, .. } => {
let name = obj_to_name(obj_id.0); let name = obj_to_name(obj_id.0);
if !params.is_empty() { if params.is_empty() {
name
} else {
let params = params let params = params
.iter() .iter()
.map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
// sort to preserve order // sort to preserve order
let mut params = params.sorted(); let mut params = params.sorted();
format!("{}[{}]", name, params.join(", ")) format!("{}[{}]", name, params.join(", "))
} else {
name
} }
} }
TypeEnum::TCall { .. } => "call".to_owned(), TypeEnum::TCall { .. } => "call".to_owned(),
@ -1056,7 +1061,7 @@ impl Unifier {
}) })
.join(", "); .join(", ");
let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes);
format!("fn[[{}], {}]", params, ret) format!("fn[[{params}], {ret}]")
} }
} }
} }
@ -1066,7 +1071,7 @@ impl Unifier {
let table = &mut self.unification_table; let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone(); let ty_b = table.probe_value(b).clone();
table.unify(a, b); table.unify(a, b);
table.set_value(a, ty_b) table.set_value(a, ty_b);
} }
fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> { fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
@ -1079,7 +1084,7 @@ impl Unifier {
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = true; let mut instantiated = true;
let mut vars = Vec::new(); let mut vars = Vec::new();
for (k, v) in fun.vars.iter() { for (k, v) in &fun.vars {
if let TypeEnum::TVar { id, name, loc, range, .. } = if let TypeEnum::TVar { id, name, loc, range, .. } =
self.unification_table.probe_value(*v).as_ref() self.unification_table.probe_value(*v).as_ref()
{ {
@ -1134,7 +1139,7 @@ impl Unifier {
// should be safe to not implement the substitution for those variants. // should be safe to not implement the substitution for those variants.
match &*ty { match &*ty {
TypeEnum::TRigidVar { .. } => None, TypeEnum::TRigidVar { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).cloned(), TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() { for (i, t) in ty.iter().enumerate() {
@ -1219,7 +1224,7 @@ impl Unifier {
K: std::hash::Hash + Eq + Clone, K: std::hash::Hash + Eq + Clone,
{ {
let mut map2 = None; let mut map2 = None;
for (k, v) in map.iter() { for (k, v) in map {
if let Some(v1) = self.subst_impl(*v, mapping, cache) { if let Some(v1) = self.subst_impl(*v, mapping, cache) {
if map2.is_none() { if map2.is_none() {
map2 = Some(map.clone()); map2 = Some(map.clone());
@ -1240,7 +1245,7 @@ impl Unifier {
K: std::hash::Hash + Eq + Clone, K: std::hash::Hash + Eq + Clone,
{ {
let mut map2 = None; let mut map2 = None;
for (k, (v, mutability)) in map.iter() { for (k, (v, mutability)) in map {
if let Some(v1) = self.subst_impl(*v, mapping, cache) { if let Some(v1) = self.subst_impl(*v, mapping, cache) {
if map2.is_none() { if map2.is_none() {
map2 = Some(map.clone()); map2 = Some(map.clone());
@ -1296,7 +1301,7 @@ impl Unifier {
if range.is_empty() { if range.is_empty() {
Ok(Some(a)) Ok(Some(a))
} else { } else {
for v in range.iter() { for v in range {
let result = self.get_intersection(a, *v); let result = self.get_intersection(a, *v);
if let Ok(result) = result { if let Ok(result) = result {
return Ok(result.or(Some(a))); return Ok(result.or(Some(a)));
@ -1338,7 +1343,7 @@ impl Unifier {
if range.is_empty() { if range.is_empty() {
return Ok(None); return Ok(None);
} }
for t in range.iter() { for t in range {
let result = self.get_intersection(*t, b); let result = self.get_intersection(*t, b);
if let Ok(result) = result { if let Ok(result) = result {
return Ok(result); return Ok(result);

View File

@ -393,7 +393,7 @@ fn main() {
let threads = (0..threads) let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{}", i), 64))) .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{}", i), 64)))
.collect(); .collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);