From 352831b2caadeb4d3725954fce36b79d5be3cab5 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sun, 13 Feb 2022 22:39:24 +0800 Subject: [PATCH 1/5] nac3core: removed legacy location definition --- nac3artiq/src/symbol_resolver.rs | 5 ---- nac3core/src/codegen/test.rs | 5 ---- nac3core/src/lib.rs | 1 - nac3core/src/location.rs | 28 ------------------- nac3core/src/symbol_resolver.rs | 3 +- nac3core/src/toplevel/test.rs | 5 ---- .../src/typecheck/type_inferencer/test.rs | 5 ---- nac3standalone/src/basic_symbol_resolver.rs | 5 ---- 8 files changed, 1 insertion(+), 56 deletions(-) delete mode 100644 nac3core/src/location.rs diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 3642763..2ce335a 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1,7 +1,6 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::{CodeGenContext, CodeGenerator}, - location::Location, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ @@ -989,10 +988,6 @@ impl SymbolResolver for Resolver { }) } - fn get_symbol_location(&self, _: StrRef) -> Option { - unimplemented!() - } - fn get_identifier_def(&self, id: StrRef) -> Option { { let id_to_def = self.0.id_to_def.read(); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ea0ccb7..e174f5f 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -3,7 +3,6 @@ use crate::{ concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry, }, - location::Location, symbol_resolver::{SymbolResolver, ValueEnum}, toplevel::{ composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, @@ -58,10 +57,6 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_symbol_location(&self, _: StrRef) -> Option { - unimplemented!() - } - fn get_identifier_def(&self, id: StrRef) -> Option { self.id_to_def.read().get(&id).cloned() } diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 10e1d25..bff6b5c 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,7 +2,6 @@ #![allow(dead_code)] pub mod codegen; -pub mod location; pub mod symbol_resolver; pub mod toplevel; pub mod typecheck; diff --git a/nac3core/src/location.rs b/nac3core/src/location.rs deleted file mode 100644 index f096852..0000000 --- a/nac3core/src/location.rs +++ /dev/null @@ -1,28 +0,0 @@ -use nac3parser::ast; -use std::vec::Vec; - -#[derive(Clone, Copy, PartialEq)] -pub struct FileID(u32); - -#[derive(Clone, Copy, PartialEq)] -pub enum Location { - CodeRange(FileID, ast::Location), - Builtin, -} - -#[derive(Default)] -pub struct FileRegistry { - files: Vec, -} - -impl FileRegistry { - pub fn add_file(&mut self, path: &str) -> FileID { - let index = self.files.len() as u32; - self.files.push(path.to_owned()); - FileID(index) - } - - pub fn query_file(&self, id: FileID) -> &str { - &self.files[id.0 as usize] - } -} diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 8d537bc..25ad548 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -13,7 +13,7 @@ use crate::{ typedef::{Type, Unifier}, }, }; -use crate::{location::Location, typecheck::typedef::TypeEnum}; +use crate::typecheck::typedef::TypeEnum; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use itertools::{chain, izip}; use nac3parser::ast::{Expr, StrRef}; @@ -113,7 +113,6 @@ pub trait SymbolResolver { ctx: &mut CodeGenContext<'ctx, 'a>, ) -> Option>; - fn get_symbol_location(&self, str: StrRef) -> Option; fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; fn get_string_id(&self, s: &str) -> i32; // handle function call etc. diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 8021bf9..99a6ed2 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -1,6 +1,5 @@ use crate::{ codegen::CodeGenContext, - location::Location, symbol_resolver::{SymbolResolver, ValueEnum}, toplevel::DefinitionId, typecheck::{ @@ -62,10 +61,6 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_symbol_location(&self, _: StrRef) -> Option { - unimplemented!() - } - fn get_identifier_def(&self, id: StrRef) -> Option { self.0.id_to_def.lock().get(&id).cloned() } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index d8156ef..e845985 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -2,7 +2,6 @@ use super::super::typedef::*; use super::*; use crate::{ codegen::CodeGenContext, - location::Location, symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, }; @@ -41,10 +40,6 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_symbol_location(&self, _: StrRef) -> Option { - unimplemented!() - } - fn get_identifier_def(&self, id: StrRef) -> Option { self.id_to_def.get(&id).cloned() } diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 6e886fd..0b079c8 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -1,6 +1,5 @@ use nac3core::{ codegen::CodeGenContext, - location::Location, symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ @@ -64,10 +63,6 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_symbol_location(&self, _: StrRef) -> Option { - unimplemented!() - } - fn get_identifier_def(&self, id: StrRef) -> Option { self.0.id_to_def.lock().get(&id).cloned() } -- 2.44.1 From d9cb506f6a759b5567ac986dbf8a690a78fce9c9 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 21 Feb 2022 17:52:34 +0800 Subject: [PATCH 2/5] nac3core: refactored for better error messages --- nac3artiq/demo/demo.py | 19 + nac3artiq/src/codegen.rs | 67 +- nac3artiq/src/lib.rs | 17 +- nac3artiq/src/symbol_resolver.rs | 59 +- nac3core/src/codegen/concrete_type.rs | 5 +- nac3core/src/codegen/expr.rs | 233 +++---- nac3core/src/codegen/generator.rs | 30 +- nac3core/src/codegen/irrt/mod.rs | 30 +- nac3core/src/codegen/mod.rs | 34 +- nac3core/src/codegen/stmt.rs | 148 +++-- nac3core/src/codegen/test.rs | 8 +- nac3core/src/symbol_resolver.rs | 97 +-- nac3core/src/toplevel/builtins.rs | 429 ++++++------ nac3core/src/toplevel/composer.rs | 269 +++++--- nac3core/src/toplevel/helper.rs | 20 +- nac3core/src/toplevel/mod.rs | 10 +- ...el__test__test_analyze__generic_class.snap | 10 +- ...t__test_analyze__inheritance_override.snap | 16 +- ...est__test_analyze__list_tuple_generic.snap | 12 +- ...__toplevel__test__test_analyze__self1.snap | 14 +- ...t__test_analyze__simple_class_compose.snap | 18 +- nac3core/src/toplevel/test.rs | 18 +- nac3core/src/toplevel/type_annotation.rs | 27 +- nac3core/src/typecheck/magic_methods.rs | 67 +- nac3core/src/typecheck/mod.rs | 1 + nac3core/src/typecheck/type_error.rs | 177 +++++ nac3core/src/typecheck/type_inferencer/mod.rs | 96 +-- .../src/typecheck/type_inferencer/test.rs | 50 +- nac3core/src/typecheck/typedef/mod.rs | 619 ++++++++++-------- nac3core/src/typecheck/typedef/test.rs | 187 +++--- nac3core/src/typecheck/unification_table.rs | 11 + nac3standalone/src/basic_symbol_resolver.rs | 4 +- nac3standalone/src/main.rs | 4 +- 33 files changed, 1619 insertions(+), 1187 deletions(-) create mode 100644 nac3core/src/typecheck/type_error.rs diff --git a/nac3artiq/demo/demo.py b/nac3artiq/demo/demo.py index a9abc18..f1a6502 100644 --- a/nac3artiq/demo/demo.py +++ b/nac3artiq/demo/demo.py @@ -1,4 +1,13 @@ from min_artiq import * +from numpy import int32, int64 + +@extern +def output_int(x: int32): + ... + + +class InexistingException(Exception): + pass @nac3 class Demo: @@ -11,6 +20,16 @@ class Demo: self.led0 = TTLOut(self.core, 18) self.led1 = TTLOut(self.core, 19) + @kernel + def test(self): + a = (1, True) + a[0]() + + @kernel + def test2(self): + a = (1, True) + output_int(int32(a)) + @kernel def run(self): self.core.reset() diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index a799fa4..6aca71d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -64,10 +64,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, - ) -> Option> { - let result = gen_call(self, ctx, obj, fun, params); + ) -> Result>, String> { + let result = gen_call(self, ctx, obj, fun, params)?; if let Some(end) = self.end.clone() { - let old_end = self.gen_expr(ctx, &end).unwrap().to_basic_value_enum(ctx, self); + let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(ctx, self); let now = self.timeline.emit_now_mu(ctx); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); @@ -83,21 +83,21 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { .try_as_basic_value() .left() .unwrap(); - let end_store = self.gen_store_target(ctx, &end); + let end_store = self.gen_store_target(ctx, &end)?; ctx.builder.build_store(end_store, max); } if let Some(start) = self.start.clone() { - let start_val = self.gen_expr(ctx, &start).unwrap().to_basic_value_enum(ctx, self); + let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(ctx, self); self.timeline.emit_at_mu(ctx, start_val); } - result + Ok(result) } fn gen_with<'ctx, 'a>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, - ) { + ) -> Result<(), String> { if let StmtKind::With { items, body, .. } = &stmt.node { if items.len() == 1 && items[0].optional_vars.is_none() { let item = &items[0]; @@ -119,7 +119,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let old_start = self.start.take(); let old_end = self.end.take(); let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx, self) + self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(ctx, self) } else { self.timeline.emit_now_mu(ctx) }; @@ -130,7 +130,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { // the LLVM Context. // The name is guaranteed to be unique as users cannot use this as variable // name. - self.start = old_start.clone().or_else(|| { + self.start = old_start.clone().map_or_else(|| { let start = format!("with-{}-start", self.name_counter).into(); let start_expr = Located { // location does not matter at this point @@ -138,10 +138,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; - let start = self.gen_store_target(ctx, &start_expr); + let start = self.gen_store_target(ctx, &start_expr)?; ctx.builder.build_store(start, now); - Some(start_expr) - }); + Ok(Some(start_expr)) as Result<_, String> + }, |v| Ok(Some(v)))?; let end = format!("with-{}-end", self.name_counter).into(); let end_expr = Located { // location does not matter at this point @@ -149,11 +149,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; - let end = self.gen_store_target(ctx, &end_expr); + let end = self.gen_store_target(ctx, &end_expr)?; ctx.builder.build_store(end, now); self.end = Some(end_expr); self.name_counter += 1; - gen_block(self, ctx, body.iter()); + gen_block(self, ctx, body.iter())?; let current = ctx.builder.get_insert_block().unwrap(); // if the current block is terminated, move before the terminator // we want to set the timeline before reaching the terminator @@ -171,7 +171,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { // set duration let end_expr = self.end.take().unwrap(); let end_val = - self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx, self); + self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(ctx, self); // inside a sequential block if old_start.is_none() { @@ -180,7 +180,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { // inside a parallel block, should update the outer max now_mu if let Some(old_end) = &old_end { let outer_end_val = - self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx, self); + self.gen_expr(ctx, old_end)?.unwrap().to_basic_value_enum(ctx, self); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); @@ -196,7 +196,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { .try_as_basic_value() .left() .unwrap(); - let outer_end = self.gen_store_target(ctx, old_end); + let outer_end = self.gen_store_target(ctx, old_end)?; ctx.builder.build_store(outer_end, max); } self.start = old_start; @@ -204,29 +204,29 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { if reset_position { ctx.builder.position_at_end(current); } - return; + return Ok(()); } else if id == &"sequential".into() { let start = self.start.take(); for stmt in body.iter() { - self.gen_stmt(ctx, stmt); + self.gen_stmt(ctx, stmt)?; if ctx.is_terminated() { break; } } self.start = start; - return + return Ok(()); } } } // not parallel/sequential - gen_with(self, ctx, stmt); + gen_with(self, ctx, stmt) } else { unreachable!() } } } -fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec) { +fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec) -> Result<(), String> { use nac3core::typecheck::typedef::TypeEnum::*; let int32 = ctx.primitives.int32; @@ -249,24 +249,25 @@ fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: & } else if ctx.unifier.unioned(ty, none) { buffer.push(b'n'); } else { - let ty = ctx.unifier.get_ty(ty); - match &*ty { + let ty_enum = ctx.unifier.get_ty(ty); + match &*ty_enum { TTuple { ty } => { buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { - gen_rpc_tag(ctx, *ty, buffer); + gen_rpc_tag(ctx, *ty, buffer)?; } } TList { ty } => { buffer.push(b'l'); - gen_rpc_tag(ctx, *ty, buffer); + gen_rpc_tag(ctx, *ty, buffer)?; } // we should return an error, this will be fixed after improving error message // as this requires returning an error during codegen - _ => unimplemented!(), + _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } + Ok(()) } fn rpc_codegen_callback_fn<'ctx, 'a>( @@ -275,7 +276,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( fun: (&FunSignature, DefinitionId), args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, -) -> Option> { +) -> Result>, String> { let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic); let size_type = generator.get_size_type(ctx.ctx); let int8 = ctx.ctx.i8_type(); @@ -289,10 +290,10 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( tag.push(b'O'); } for arg in fun.0.args.iter() { - gen_rpc_tag(ctx, arg.ty, &mut tag); + gen_rpc_tag(ctx, arg.ty, &mut tag)?; } tag.push(b':'); - gen_rpc_tag(ctx, fun.0.ret, &mut tag); + gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); @@ -432,7 +433,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); - return None + return Ok(None) } let prehead_bb = ctx.builder.get_insert_block().unwrap(); @@ -474,7 +475,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( ctx.builder.position_at_end(tail_bb); - if need_load { + Ok(if need_load { let result = ctx.builder.build_load(slot, "rpc.result"); ctx.builder.build_call( stackrestore, @@ -484,7 +485,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( Some(result) } else { Some(slot.into()) - } + }) } pub fn rpc_codegen_callback() -> Arc { diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 085151e..5707551 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -203,7 +203,7 @@ impl Nac3 { let fun_ty = if method_name.is_empty() { base_ty } else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) { - match fields.borrow().get(&(*method_name).into()) { + match fields.get(&(*method_name).into()) { Some(t) => t.0, None => return Some( format!("object launching kernel does not have method `{}`", method_name) @@ -213,8 +213,7 @@ impl Nac3 { return Some("cannot launch kernel by calling a non-callable".into()) }; - if let TypeEnum::TFunc(sig) = &*unifier.get_ty(fun_ty) { - let FunSignature { args, .. } = &*sig.borrow(); + if let TypeEnum::TFunc(FunSignature { args, .. }) = &*unifier.get_ty(fun_ty) { if arg_names.len() > args.len() { return Some(format!( "launching kernel function with too many arguments (expect {}, found {})", @@ -243,7 +242,7 @@ impl Nac3 { }; if let Err(e) = unifier.unify(in_ty, *ty) { return Some(format!( - "type error ({}) at parameter #{} when calling kernel function", e, i + "type error ({}) at parameter #{} when calling kernel function", e.to_display(unifier).to_string(), i )); } } @@ -281,7 +280,7 @@ impl Nac3 { vars: HashMap::new(), }, Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { - Some(time_fns.emit_now_mu(ctx)) + Ok(Some(time_fns.emit_now_mu(ctx))) }))), ), ( @@ -298,7 +297,7 @@ impl Nac3 { Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| { let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); time_fns.emit_at_mu(ctx, arg); - None + Ok(None) }))), ), ( @@ -315,7 +314,7 @@ impl Nac3 { Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| { let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); time_fns.emit_delay_mu(ctx, arg); - None + Ok(None) }))), ), ]; @@ -536,7 +535,7 @@ impl Nac3 { let (name, def_id, ty) = composer .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .map_err(|e| { - exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)) + exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure\n----------\n{}", e)) })?; match &stmt.node { @@ -637,7 +636,7 @@ impl Nac3 { // report error of __modinit__ separately if !e.contains("__nac3_synthesized_modinit__") { return Err(exceptions::PyRuntimeError::new_err( - format!("nac3 compilation failure: {}", e) + format!("nac3 compilation failure: \n----------\n{}", e) )); } else { let msg = Self::report_modinit( diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 2ce335a..db513f8 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -15,7 +15,6 @@ use pyo3::{ PyAny, PyObject, PyResult, Python, }; use std::{ - cell::RefCell, collections::{HashMap, HashSet}, sync::Arc, }; @@ -208,7 +207,7 @@ impl InnerResolver { ty = match unifier.unify(ty, b) { Ok(_) => ty, Err(e) => return Ok(Err(format!( - "inhomogeneous type ({}) at element #{} of the list", e, i + "inhomogeneous type ({}) at element #{} of the list", e.to_display(unifier).to_string(), i ))) }; } @@ -246,7 +245,7 @@ impl InnerResolver { Ok(Ok((primitives.exception, true))) }else if ty_id == self.primitive_ids.list { // do not handle type var param and concrete check here - let var = unifier.get_fresh_var().0; + let var = unifier.get_dummy_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.tuple { @@ -266,8 +265,7 @@ impl InnerResolver { Ok(Ok({ let ty = TypeEnum::TObj { obj_id: *object_id, - params: RefCell::new({ - type_vars + params: type_vars .iter() .map(|x| { if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { @@ -276,16 +274,15 @@ impl InnerResolver { unreachable!() } }) - .collect() - }), - fields: RefCell::new({ + .collect(), + fields: { let mut res = methods .iter() .map(|(iden, ty, _)| (*iden, (*ty, false))) .collect::>(); res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); res - }), + }, }; // here also false, later instantiation use python object to check compatible (unifier.add_ty(ty), false) @@ -295,6 +292,7 @@ impl InnerResolver { unreachable!("function type is not supported, should not be queried") } } else if ty_ty_id == self.primitive_ids.typevar { + let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap(); let constraint_types = { let constraints = pyty.getattr("__constraints__").unwrap(); let mut result: Vec = vec![]; @@ -322,7 +320,7 @@ impl InnerResolver { } result }; - let res = unifier.get_fresh_var_with_range(&constraint_types).0; + let res = unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; Ok(Ok((res, true))) } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 @@ -388,7 +386,6 @@ impl InnerResolver { } TypeEnum::TObj { params, obj_id, .. } => { let subst = { - let params = &*params.borrow(); if params.len() != args.len() { return Ok(Err(format!( "for class #{}, expect {} type parameters, got {}.", @@ -456,14 +453,16 @@ impl InnerResolver { Ok(Ok(( { let ty = TypeEnum::TVirtual { - ty: unifier.get_fresh_var().0, + ty: unifier.get_dummy_var().0, }; unifier.add_ty(ty) }, false, ))) } else { - Ok(Err("unknown type".into())) + let str_fn = pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); + let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); + Ok(Err(format!("{} is not supported in nac3 (did you forgot to put @nac3 annotation?)", str_repr))) } } @@ -510,8 +509,8 @@ impl InnerResolver { if len == 0 { assert!(matches!( &*unifier.get_ty(extracted_ty), - TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. } - if range.borrow().is_empty() + TypeEnum::TVar { fields: None, range, .. } + if range.is_empty() )); Ok(Ok(extracted_ty)) } else { @@ -520,7 +519,7 @@ impl InnerResolver { match actual_ty { Ok(t) => match unifier.unify(*ty, t) { Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList{ ty: *ty }))), - Err(e) => Ok(Err(format!("type error ({}) for the list", e))), + Err(e) => Ok(Err(format!("type error ({}) for the list", e.to_display(unifier).to_string()))), } Err(e) => Ok(Err(e)), } @@ -537,19 +536,18 @@ impl InnerResolver { } (TypeEnum::TObj { params, fields, .. }, false) => { let var_map = params - .borrow() .iter() .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) { + if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) { assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(&range.borrow()).0) + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) } else { unreachable!() } }) .collect::>(); // loop through non-function fields of the class to get the instantiated value - for field in fields.borrow().iter() { + for field in fields.iter() { let name: String = (*field.0).into(); if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { continue; @@ -566,7 +564,7 @@ impl InnerResolver { if let Err(e) = unifier.unify(ty, field_ty) { // field type mismatch return Ok(Err(format!( - "error when getting type of field `{}` ({})", name, e + "error when getting type of field `{}` ({})", name, e.to_display(unifier).to_string() ))); } } @@ -988,18 +986,19 @@ impl SymbolResolver for Resolver { }) } - fn get_identifier_def(&self, id: StrRef) -> Option { + fn get_identifier_def(&self, id: StrRef) -> Result { { let id_to_def = self.0.id_to_def.read(); - id_to_def.get(&id).cloned() + id_to_def.get(&id).cloned().ok_or_else(|| "".to_string()) } - .or_else(|| { - let py_id = self.0.name_to_pyid.get(&id); - let result = py_id.and_then(|id| self.0.pyid_to_def.read().get(id).copied()); - if let Some(result) = &result { - self.0.id_to_def.write().insert(id, *result); - } - result + .or_else(|_| { + let py_id = self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?; + let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or(format!( + "`{}` is not registered in nac3, did you forgot to add @nac3?", + id + ))?; + self.0.id_to_def.write().insert(id, result); + Ok(result) }) } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index fd13afa..36acaaa 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -157,7 +157,6 @@ impl ConcreteTypeStore { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { obj_id: *obj_id, fields: fields - .borrow() .iter() .filter_map(|(name, ty)| { // here we should not have type vars, but some partial instantiated @@ -171,7 +170,6 @@ impl ConcreteTypeStore { }) .collect(), params: params - .borrow() .iter() .map(|(id, ty)| { (*id, self.from_unifier_type(unifier, primitives, *ty, cache)) @@ -182,7 +180,6 @@ impl ConcreteTypeStore { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, TypeEnum::TFunc(signature) => { - let signature = signature.borrow(); self.from_signature(unifier, primitives, &*signature, cache) } _ => unreachable!(), @@ -210,7 +207,7 @@ impl ConcreteTypeStore { return if let Some(ty) = ty { *ty } else { - *ty = Some(unifier.get_fresh_var().0); + *ty = Some(unifier.get_dummy_var().0); ty.unwrap() }; } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index badd9f4..11b8554 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -31,7 +31,7 @@ pub fn get_subst_key( let mut vars = obj .map(|ty| { if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { - params.borrow().clone() + params.clone() } else { unreachable!() } @@ -40,7 +40,7 @@ pub fn get_subst_key( vars.extend(fun_vars.iter()); let sorted = vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); sorted - .map(|id| unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())) + .map(|id| unifier.internal_stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None)) .join(", ") } @@ -352,7 +352,7 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( signature: &FunSignature, def: &TopLevelDef, params: Vec<(Option, ValueEnum<'ctx>)>, -) -> BasicValueEnum<'ctx> { +) -> Result, String> { match def { TopLevelDef::Class { methods, .. } => { // TODO: what about other fields that require alloca? @@ -374,9 +374,9 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( Some((signature.ret, zelf.into())), (&sign, fun_id), params, - ); + )?; } - zelf + Ok(zelf) } _ => unreachable!(), } @@ -387,7 +387,7 @@ pub fn gen_func_instance<'ctx, 'a>( obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, &mut TopLevelDef, String), id: usize, -) -> String { +) -> Result { if let ( sign, TopLevelDef::Function { @@ -396,56 +396,57 @@ pub fn gen_func_instance<'ctx, 'a>( key, ) = fun { - instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { - let symbol = format!("{}.{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key.clone(), symbol.clone()); - let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id)); - let instance = instance_to_stmt.get(&key).unwrap(); + if let Some(sym) = instance_to_symbol.get(&key) { + return Ok(sym.clone()); + } + let symbol = format!("{}.{}", name, instance_to_symbol.len()); + instance_to_symbol.insert(key.clone(), symbol.clone()); + let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id)); + let instance = instance_to_stmt.get(&key).unwrap(); - let mut store = ConcreteTypeStore::new(); - let mut cache = HashMap::new(); + let mut store = ConcreteTypeStore::new(); + let mut cache = HashMap::new(); - let subst = sign - .vars - .iter() - .map(|(id, ty)| { - ( - *instance.subst.get(id).unwrap(), - store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, *ty, &mut cache), - ) - }) - .collect(); + let subst = sign + .vars + .iter() + .map(|(id, ty)| { + ( + *instance.subst.get(id).unwrap(), + store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, *ty, &mut cache), + ) + }) + .collect(); - let mut signature = - store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); + let mut signature = + store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); - if let Some(obj) = &obj { - let zelf = - store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); - if let ConcreteTypeEnum::TFunc { args, .. } = &mut signature { - args.insert( - 0, - ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, - ) - } else { - unreachable!() - } + if let Some(obj) = &obj { + let zelf = + store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); + if let ConcreteTypeEnum::TFunc { args, .. } = &mut signature { + args.insert( + 0, + ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, + ) + } else { + unreachable!() } - let signature = store.add_cty(signature); + } + let signature = store.add_cty(signature); - ctx.registry.add_task(CodeGenTask { - symbol_name: symbol.clone(), - body: instance.body.clone(), - resolver: resolver.as_ref().unwrap().clone(), - calls: instance.calls.clone(), - subst, - signature, - store, - unifier_index: instance.unifier_id, - id, - }); - symbol - }) + ctx.registry.add_task(CodeGenTask { + symbol_name: symbol.clone(), + body: instance.body.clone(), + resolver: resolver.as_ref().unwrap().clone(), + calls: instance.calls.clone(), + subst, + signature, + store, + unifier_index: instance.unifier_id, + id, + }); + Ok(symbol) } else { unreachable!() } @@ -457,9 +458,8 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, -) -> Option> { +) -> Result>, String> { let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); - let id; let key; let param_vals; @@ -492,7 +492,6 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( if let Some(obj) = &obj { real_params.insert(0, obj.1.clone()); } - let static_params = real_params .iter() .enumerate() @@ -530,16 +529,16 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( .into_iter() .map(|p| p.to_basic_value_enum(ctx, generator)) .collect_vec(); - instance_to_symbol.get(&key).cloned() + instance_to_symbol.get(&key).cloned().ok_or_else(|| "".into()) } TopLevelDef::Class { .. } => { - return Some(generator.gen_constructor(ctx, fun.0, &*def, params)) + return Ok(Some(generator.gen_constructor(ctx, fun.0, &*def, params)?)) } } } - .unwrap_or_else(|| { + .or_else(|_: String| { generator.gen_func_instance(ctx, obj.clone(), (fun.0, &mut *definition.write(), key), id) - }); + })?; let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| { let mut args = fun.0.args.clone(); if let Some(obj) = &obj { @@ -554,8 +553,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( }; ctx.module.add_function(&symbol, fun_ty, None) }); - - ctx.build_call_or_invoke(fun_val, ¶m_vals, "call") + Ok(ctx.build_call_or_invoke(fun_val, ¶m_vals, "call")) } pub fn destructure_range<'ctx, 'a>( @@ -607,7 +605,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, expr: &Expr>, -) -> BasicValueEnum<'ctx> { +) -> Result, String> { if let ExprKind::ListComp { elt, generators } = &expr.node { let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let test_bb = ctx.ctx.append_basic_block(current, "test"); @@ -615,13 +613,13 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let Comprehension { target, iter, ifs, .. } = &generators[0]; - let iter_val = generator.gen_expr(ctx, iter).unwrap().to_basic_value_enum(ctx, generator); + let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(ctx, generator); let int32 = ctx.ctx.i32_type(); let size_t = generator.get_size_type(ctx.ctx); let zero_size_t = size_t.const_zero(); let zero_32 = int32.const_zero(); - let index = generator.gen_var_alloc(ctx, size_t.into()); + let index = generator.gen_var_alloc(ctx, size_t.into())?; ctx.builder.build_store(index, zero_size_t); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); @@ -664,7 +662,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32]).into_pointer_value(); - let i = generator.gen_store_target(ctx, target); + let i = generator.gen_store_target(ctx, target)?; ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); @@ -699,7 +697,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( list = allocate_list(generator, ctx, elem_ty, length); list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32]).into_pointer_value(); - let counter = generator.gen_var_alloc(ctx, size_t.into()); + let counter = generator.gen_var_alloc(ctx, size_t.into())?; // counter = -1 ctx.builder.build_store(counter, size_t.const_int(u64::max_value(), true)); ctx.builder.build_unconditional_branch(test_bb); @@ -714,11 +712,11 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( .build_gep_and_load(iter_val.into_pointer_value(), &[zero_size_t, zero_32]) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); - generator.gen_assign(ctx, target, val.into()); + generator.gen_assign(ctx, target, val.into())?; } for cond in ifs.iter() { let result = generator - .gen_expr(ctx, cond) + .gen_expr(ctx, cond)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -726,7 +724,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( ctx.builder.build_conditional_branch(result, succ, test_bb); ctx.builder.position_at_end(succ); } - let elem = generator.gen_expr(ctx, elt).unwrap(); + let elem = generator.gen_expr(ctx, elt)?.unwrap(); let i = ctx.builder.build_load(index, "i").into_int_value(); let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }; let val = elem.to_basic_value_enum(ctx, generator); @@ -739,7 +737,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length") }; ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); - list.into() + Ok(list.into()) } else { unreachable!() } @@ -751,16 +749,16 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( left: &Expr>, op: &Operator, right: &Expr>, -) -> ValueEnum<'ctx> { +) -> Result, String> { let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx, generator); - let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx, generator); + let left = generator.gen_expr(ctx, left)?.unwrap().to_basic_value_enum(ctx, generator); + let right = generator.gen_expr(ctx, right)?.unwrap().to_basic_value_enum(ctx, generator); // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do // when doing code generation for function instances - if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { ctx.gen_int_ops(op, left, right) } else if ty1 == ty2 && ctx.primitives.float == ty1 { ctx.gen_float_ops(op, left, right) @@ -783,17 +781,17 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( } else { unimplemented!() } - .into() + .into()) } pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, expr: &Expr>, -) -> Option> { +) -> Result>, String> { let int32 = ctx.ctx.i32_type(); let zero = int32.const_int(0, false); - Some(match &expr.node { + Ok(Some(match &expr.node { ExprKind::Constant { value, .. } => { let ty = expr.custom.unwrap(); ctx.gen_const(generator, value, ty).into() @@ -823,8 +821,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( // we should use memcpy for that instead of generating thousands of stores let elements = elts .iter() - .map(|x| generator.gen_expr(ctx, x).unwrap().to_basic_value_enum(ctx, generator)) - .collect_vec(); + .map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) + .collect::, _>>()?; let ty = if elements.is_empty() { if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) { ctx.get_llvm_type(generator, *ty) @@ -852,8 +850,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ExprKind::Tuple { elts, .. } => { let element_val = elts .iter() - .map(|x| generator.gen_expr(ctx, x).unwrap().to_basic_value_enum(ctx, generator)) - .collect_vec(); + .map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) + .collect::, _>>()?; let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let tuple_ty = ctx.ctx.struct_type(&element_ty, false); let tuple_ptr = ctx.builder.build_alloca(tuple_ty, "tuple"); @@ -871,7 +869,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls - match generator.gen_expr(ctx, value).unwrap() { + match generator.gen_expr(ctx, value)?.unwrap() { ValueEnum::Static(v) => v.get_field(*attr, ctx).unwrap_or_else(|| { let v = v.to_basic_value_enum(ctx, generator); let index = ctx.get_attr_index(value.custom.unwrap(), *attr); @@ -892,7 +890,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ExprKind::BoolOp { op, values } => { // requires conditional branches for short-circuiting... let left = generator - .gen_expr(ctx, &values[0]) + .gen_expr(ctx, &values[0])? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -908,7 +906,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(b_bb); let b = generator - .gen_expr(ctx, &values[1]) + .gen_expr(ctx, &values[1])? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -918,7 +916,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( Boolop::And => { ctx.builder.position_at_end(a_bb); let a = generator - .gen_expr(ctx, &values[1]) + .gen_expr(ctx, &values[1])? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -934,10 +932,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.as_basic_value().into() } - ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right), + ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); - let val = generator.gen_expr(ctx, operand).unwrap().to_basic_value_enum(ctx, generator); + let val = generator.gen_expr(ctx, operand)?.unwrap().to_basic_value_enum(ctx, generator); if ty == ctx.primitives.bool { let val = val.into_int_value(); match op { @@ -984,7 +982,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } ExprKind::Compare { left, ops, comparators } => { izip!(chain(once(left.as_ref()), comparators.iter()), comparators.iter(), ops.iter(),) - .fold(None, |prev, (lhs, rhs, op)| { + .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let ty = ctx.unifier.get_representative(lhs.custom.unwrap()); let current = if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.bool] @@ -995,11 +993,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( BasicValueEnum::IntValue(rhs), ) = ( generator - .gen_expr(ctx, lhs) + .gen_expr(ctx, lhs)? .unwrap() .to_basic_value_enum(ctx, generator), generator - .gen_expr(ctx, rhs) + .gen_expr(ctx, rhs)? .unwrap() .to_basic_value_enum(ctx, generator), ) { @@ -1023,11 +1021,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( BasicValueEnum::FloatValue(rhs), ) = ( generator - .gen_expr(ctx, lhs) + .gen_expr(ctx, lhs)? .unwrap() .to_basic_value_enum(ctx, generator), generator - .gen_expr(ctx, rhs) + .gen_expr(ctx, rhs)? .unwrap() .to_basic_value_enum(ctx, generator), ) { @@ -1048,14 +1046,14 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } else { unimplemented!() }; - prev.map(|v| ctx.builder.build_and(v, current, "cmp")).or(Some(current)) - }) + Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp")).or(Some(current))) + })? .unwrap() .into() // as there should be at least 1 element, it should never be none } ExprKind::IfExp { test, body, orelse } => { let test = generator - .gen_expr(ctx, test) + .gen_expr(ctx, test)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -1065,10 +1063,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(test, then_bb, else_bb); ctx.builder.position_at_end(then_bb); - let a = generator.gen_expr(ctx, body).unwrap().to_basic_value_enum(ctx, generator); + let a = generator.gen_expr(ctx, body)?.unwrap().to_basic_value_enum(ctx, generator); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(else_bb); - let b = generator.gen_expr(ctx, orelse).unwrap().to_basic_value_enum(ctx, generator); + let b = generator.gen_expr(ctx, orelse)?.unwrap().to_basic_value_enum(ctx, generator); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(cont_bb); let phi = ctx.builder.build_phi(a.get_type(), "ifexpr"); @@ -1077,13 +1075,15 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } ExprKind::Call { func, args, keywords } => { let mut params = - args.iter().map(|arg| (None, generator.gen_expr(ctx, arg).unwrap())).collect_vec(); + args.iter().map(|arg| Ok((None, generator.gen_expr(ctx, arg)?.unwrap())) as Result<_, String>) + .collect::, _>>()?; let kw_iter = keywords.iter().map(|kw| { - ( + Ok(( Some(*kw.node.arg.as_ref().unwrap()), - generator.gen_expr(ctx, &kw.node.value).unwrap(), - ) + generator.gen_expr(ctx, &kw.node.value)?.unwrap(), + )) as Result<_, String> }); + let kw_iter = kw_iter.collect::, _>>()?; params.extend(kw_iter); let call = ctx.calls.get(&expr.location.into()); let signature = match call { @@ -1091,22 +1091,23 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( None => { let ty = func.custom.unwrap(); if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) { - sign.borrow().clone() + sign.clone() } else { unreachable!() } } }; - match &func.as_ref().node { + let func = func.as_ref(); + match &func.node { ExprKind::Name { id, .. } => { // TODO: handle primitive casts and function pointers - let fun = ctx.resolver.get_identifier_def(*id).expect("Unknown identifier"); - return generator - .gen_call(ctx, None, (&signature, fun), params) - .map(|v| v.into()); + let fun = ctx.resolver.get_identifier_def(*id).map_err(|e| format!("{} (at {})", e, func.location))?; + return Ok(generator + .gen_call(ctx, None, (&signature, fun), params)? + .map(|v| v.into())); } ExprKind::Attribute { value, attr, .. } => { - let val = generator.gen_expr(ctx, value).unwrap(); + let val = generator.gen_expr(ctx, value)?.unwrap(); let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { @@ -1129,14 +1130,14 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( unreachable!() } }; - return generator + return Ok(generator .gen_call( ctx, Some((value.custom.unwrap(), val)), (&signature, fun_id), params, - ) - .map(|v| v.into()); + )? + .map(|v| v.into())); } _ => unimplemented!(), } @@ -1144,7 +1145,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ExprKind::Subscript { value, slice, .. } => { if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(value.custom.unwrap()) { let v = generator - .gen_expr(ctx, value) + .gen_expr(ctx, value)? .unwrap() .to_basic_value_enum(ctx, generator) .into_pointer_value(); @@ -1153,7 +1154,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( if let ExprKind::Slice { lower, upper, step } = &slice.node { let one = int32.const_int(1, false); let (start, end, step) = - handle_slice_indices(lower, upper, step, ctx, generator, v); + handle_slice_indices(lower, upper, step, ctx, generator, v)?; let length = calculate_len_for_slice_range( ctx, start, @@ -1174,7 +1175,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ); let res_array_ret = allocate_list(generator, ctx, ty, length); let res_ind = - handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret); + handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)?; list_slice_assignment( ctx, generator.get_size_type(ctx.ctx), @@ -1189,7 +1190,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let len = ctx.build_gep_and_load(v, &[zero, int32.const_int(1, false)]) .into_int_value(); let raw_index = generator - .gen_expr(ctx, slice) + .gen_expr(ctx, slice)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -1208,7 +1209,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } } else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { let v = generator - .gen_expr(ctx, value) + .gen_expr(ctx, value)? .unwrap() .to_basic_value_enum(ctx, generator) .into_struct_value(); @@ -1224,7 +1225,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } } .into(), - ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr).into(), + ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr)?.into(), _ => unimplemented!(), - }) + })) } diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 0c9858f..25d2a27 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -28,7 +28,7 @@ pub trait CodeGenerator { obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, - ) -> Option> + ) -> Result>, String> where Self: Sized, { @@ -45,7 +45,7 @@ pub trait CodeGenerator { signature: &FunSignature, def: &TopLevelDef, params: Vec<(Option, ValueEnum<'ctx>)>, - ) -> BasicValueEnum<'ctx> + ) -> Result, String> where Self: Sized, { @@ -65,7 +65,7 @@ pub trait CodeGenerator { obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, &mut TopLevelDef, String), id: usize, - ) -> String { + ) -> Result { gen_func_instance(ctx, obj, fun, id) } @@ -74,7 +74,7 @@ pub trait CodeGenerator { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, expr: &Expr>, - ) -> Option> + ) -> Result>, String> where Self: Sized, { @@ -87,7 +87,7 @@ pub trait CodeGenerator { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, ty: BasicTypeEnum<'ctx>, - ) -> PointerValue<'ctx> { + ) -> Result, String> { gen_var(ctx, ty) } @@ -96,7 +96,7 @@ pub trait CodeGenerator { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, pattern: &Expr>, - ) -> PointerValue<'ctx> + ) -> Result, String> where Self: Sized, { @@ -109,7 +109,8 @@ pub trait CodeGenerator { ctx: &mut CodeGenContext<'ctx, 'a>, target: &Expr>, value: ValueEnum<'ctx>, - ) where + ) -> Result<(), String> + where Self: Sized, { gen_assign(self, ctx, target, value) @@ -118,44 +119,49 @@ pub trait CodeGenerator { /// Generate code for a while expression. /// Return true if the while loop must early return fn gen_while<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) + -> Result<(), String> where Self: Sized, { - gen_while(self, ctx, stmt); + gen_while(self, ctx, stmt) } /// Generate code for a while expression. /// Return true if the while loop must early return fn gen_for<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) + -> Result<(), String> where Self: Sized, { - gen_for(self, ctx, stmt); + gen_for(self, ctx, stmt) } /// Generate code for an if expression. /// Return true if the statement must early return fn gen_if<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) + -> Result<(), String> where Self: Sized, { - gen_if(self, ctx, stmt); + gen_if(self, ctx, stmt) } fn gen_with<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) + -> Result<(), String> where Self: Sized, { - gen_with(self, ctx, stmt); + gen_with(self, ctx, stmt) } /// Generate code for a statement /// Return true if the statement must early return fn gen_stmt<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) + -> Result<(), String> where Self: Sized, { - gen_stmt(self, ctx, stmt); + gen_stmt(self, ctx, stmt) } } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index f6c5b98..cc03bcf 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -116,7 +116,7 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>( /// case Some(e): /// handle_in_bound(e) + 1 /// ,step -/// ) +/// ) /// ``` pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( start: &Option>>>, @@ -125,31 +125,31 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut G, list: PointerValue<'ctx>, -) -> (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>) { +) -> Result<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), String> { // TODO: throw exception when step is 0 let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let one = int32.const_int(1, false); let length = ctx.build_gep_and_load(list, &[zero, one]).into_int_value(); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32"); - match (start, end, step) { + Ok(match (start, end, step) { (s, e, None) => ( s.as_ref().map_or_else( - || int32.const_zero(), + || Ok(int32.const_zero()), |s| handle_slice_index_bound(s, ctx, generator, length), - ), + )?, { let e = e.as_ref().map_or_else( - || length, + || Ok(length), |e| handle_slice_index_bound(e, ctx, generator, length), - ); + )?; ctx.builder.build_int_sub(e, one, "final_end") }, one, ), (s, e, Some(step)) => { let step = generator - .gen_expr(ctx, step) + .gen_expr(ctx, step)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -158,7 +158,7 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( ( match s { Some(s) => { - let s = handle_slice_index_bound(s, ctx, generator, length); + let s = handle_slice_index_bound(s, ctx, generator, length)?; ctx.builder .build_select( ctx.builder.build_and( @@ -181,7 +181,7 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( }, match e { Some(e) => { - let e = handle_slice_index_bound(e, ctx, generator, length); + let e = handle_slice_index_bound(e, ctx, generator, length)?; ctx.builder .build_select( neg, @@ -196,7 +196,7 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( step, ) } - } + }) } /// this function allows index out of range, since python @@ -206,7 +206,7 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut G, length: IntValue<'ctx>, -) -> IntValue<'ctx> { +) -> Result, String> { const SYMBOL: &str = "__nac3_slice_index_bound"; let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { let i32_t = ctx.ctx.i32_type(); @@ -214,13 +214,13 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>( ctx.module.add_function(SYMBOL, fn_t, None) }); - let i = generator.gen_expr(ctx, i).unwrap().to_basic_value_enum(ctx, generator); - ctx.builder + let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator); + Ok(ctx.builder .build_call(func, &[i.into(), length.into()], "bounded_ind") .try_as_basic_value() .left() .unwrap() - .into_int_value() + .into_int_value()) } /// This function handles 'end' **inclusively**. diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 3ca03ef..c34f0fb 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -206,16 +206,26 @@ impl WorkerRegistry { let passes = PassManager::create(&module); pass_builder.populate_function_pass_manager(&passes); + let mut errors = Vec::new(); while let Some(task) = self.receiver.recv().unwrap() { let tmp_module = context.create_module("tmp"); - let result = gen_func(&context, generator, self, builder, tmp_module, task); - builder = result.0; - passes.run_on(&result.2); - module.link_in_module(result.1).unwrap(); - // module = result.1; + match gen_func(&context, generator, self, builder, tmp_module, task) { + Ok(result) => { + builder = result.0; + passes.run_on(&result.2); + module.link_in_module(result.1).unwrap(); + } + Err((old_builder, e)) => { + builder = old_builder; + errors.push(e); + } + } *self.task_count.lock() -= 1; self.wait_condvar.notify_all(); } + if !errors.is_empty() { + panic!("Codegen error: {}", errors.iter().join("\n----------\n")); + } let result = module.verify(); if let Err(err) = result { @@ -267,7 +277,6 @@ fn get_llvm_type<'ctx>( let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } = &*definition.read() { let struct_type = ctx.opaque_struct_type(&name.to_string()); - let fields = fields.borrow(); let fields = fields_list .iter() .map(|f| get_llvm_type(ctx, generator, unifier, top_level, type_cache, fields[&f.0].0)) @@ -309,7 +318,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, -) -> (Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>) { +) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { let top_level_ctx = registry.top_level_ctx.clone(); let static_value_store = registry.static_value_store.clone(); let (mut unifier, primitives) = { @@ -478,8 +487,12 @@ pub fn gen_func<'ctx, G: CodeGenerator>( static_value_store, }; + let mut err = None; for stmt in task.body.iter() { - generator.gen_stmt(&mut code_gen_context, stmt); + if let Err(e) = generator.gen_stmt(&mut code_gen_context, stmt) { + err = Some(e); + break; + } if code_gen_context.is_terminated() { break; } @@ -490,6 +503,9 @@ pub fn gen_func<'ctx, G: CodeGenerator>( } let CodeGenContext { builder, module, .. } = code_gen_context; + if let Some(e) = err { + return Err((builder, e)); + } - (builder, module, fn_val) + Ok((builder, module, fn_val)) } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 31f00b9..247cd29 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -22,33 +22,33 @@ use std::convert::TryFrom; pub fn gen_var<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, ty: BasicTypeEnum<'ctx>, -) -> PointerValue<'ctx> { +) -> Result, String> { // put the alloca in init block let current = ctx.builder.get_insert_block().unwrap(); // position before the last branching instruction... ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap()); let ptr = ctx.builder.build_alloca(ty, "tmp"); ctx.builder.position_at_end(current); - ptr + Ok(ptr) } pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, pattern: &Expr>, -) -> PointerValue<'ctx> { +) -> Result, String> { // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples - match &pattern.node { - ExprKind::Name { id, .. } => ctx.var_assignment.get(id).map(|v| v.0).unwrap_or_else(|| { + Ok(match &pattern.node { + ExprKind::Name { id, .. } => ctx.var_assignment.get(id).map(|v| Ok(v.0) as Result<_, String>).unwrap_or_else(|| { let ptr_ty = ctx.get_llvm_type(generator, pattern.custom.unwrap()); - let ptr = generator.gen_var_alloc(ctx, ptr_ty); + let ptr = generator.gen_var_alloc(ctx, ptr_ty)?; ctx.var_assignment.insert(*id, (ptr, None, 0)); - ptr - }), + Ok(ptr) + })?, ExprKind::Attribute { value, attr, .. } => { let index = ctx.get_attr_index(value.custom.unwrap(), *attr); - let val = generator.gen_expr(ctx, value).unwrap().to_basic_value_enum(ctx, generator); + let val = generator.gen_expr(ctx, value)?.unwrap().to_basic_value_enum(ctx, generator); let ptr = if let BasicValueEnum::PointerValue(v) = val { v } else { @@ -68,12 +68,12 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( ExprKind::Subscript { value, slice, .. } => { let i32_type = ctx.ctx.i32_type(); let v = generator - .gen_expr(ctx, value) + .gen_expr(ctx, value)? .unwrap() .to_basic_value_enum(ctx, generator) .into_pointer_value(); let index = generator - .gen_expr(ctx, slice) + .gen_expr(ctx, slice)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); @@ -85,7 +85,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( } } _ => unreachable!(), - } + }) } pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( @@ -93,8 +93,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, 'a>, target: &Expr>, value: ValueEnum<'ctx>, -) { - match &target.node { +) -> Result<(), String> { + Ok(match &target.node { ExprKind::Tuple { elts, .. } => { if let BasicValueEnum::StructValue(v) = value.to_basic_value_enum(ctx, generator) { for (i, elt) in elts.iter().enumerate() { @@ -102,7 +102,7 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( .builder .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") .unwrap(); - generator.gen_assign(ctx, elt, v.into()); + generator.gen_assign(ctx, elt, v.into())?; } } else { unreachable!() @@ -113,12 +113,12 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( { if let ExprKind::Slice { lower, upper, step } = &slice.node { let ls = generator - .gen_expr(ctx, ls) + .gen_expr(ctx, ls)? .unwrap() .to_basic_value_enum(ctx, generator) .into_pointer_value(); let (start, end, step) = - handle_slice_indices(lower, upper, step, ctx, generator, ls); + handle_slice_indices(lower, upper, step, ctx, generator, ls)?; let value = value.to_basic_value_enum(ctx, generator).into_pointer_value(); let ty = if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) @@ -127,7 +127,7 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( } else { unreachable!() }; - let src_ind = handle_slice_indices(&None, &None, &None, ctx, generator, value); + let src_ind = handle_slice_indices(&None, &None, &None, ctx, generator, value)?; list_slice_assignment( ctx, generator.get_size_type(ctx.ctx), @@ -142,7 +142,7 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( } } _ => { - let ptr = generator.gen_store_target(ctx, target); + let ptr = generator.gen_store_target(ctx, target)?; if let ExprKind::Name { id, .. } = &target.node { let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap(); *counter += 1; @@ -153,14 +153,14 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( let val = value.to_basic_value_enum(ctx, generator); ctx.builder.build_store(ptr, val); } - } + }) } pub fn gen_for<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, -) { +) -> Result<(), String> { if let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node { // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -179,11 +179,11 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( // store loop bb information and restore it later let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); - let iter_val = generator.gen_expr(ctx, iter).unwrap().to_basic_value_enum(ctx, generator); + let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(ctx, generator); if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) { // setup let iter_val = iter_val.into_pointer_value(); - let i = generator.gen_store_target(ctx, target); + let i = generator.gen_store_target(ctx, target)?; let (start, end, step) = destructure_range(ctx, iter_val); ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); ctx.builder.build_unconditional_branch(test_bb); @@ -214,7 +214,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( ); ctx.builder.position_at_end(body_bb); } else { - let counter = generator.gen_var_alloc(ctx, size_t.into()); + let counter = generator.gen_var_alloc(ctx, size_t.into())?; // counter = -1 ctx.builder.build_store(counter, size_t.const_int(u64::max_value(), true)); let len = ctx @@ -235,10 +235,10 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero]) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); - generator.gen_assign(ctx, target, val.into()); + generator.gen_assign(ctx, target, val.into())?; } - gen_block(generator, ctx, body.iter()); + gen_block(generator, ctx, body.iter())?; for (k, (_, _, counter)) in var_assignment.iter() { let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); if counter != counter2 { @@ -250,7 +250,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( } if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); - gen_block(generator, ctx, orelse.iter()); + gen_block(generator, ctx, orelse.iter())?; if !ctx.is_terminated() { ctx.builder.build_unconditional_branch(cont_bb); } @@ -266,13 +266,14 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( } else { unreachable!() } + Ok(()) } pub fn gen_while<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, -) { +) -> Result<(), String> { if let StmtKind::While { test, body, orelse, .. } = &stmt.node { // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -289,14 +290,14 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator>( let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); - let test = generator.gen_expr(ctx, test).unwrap().to_basic_value_enum(ctx, generator); + let test = generator.gen_expr(ctx, test)?.unwrap().to_basic_value_enum(ctx, generator); if let BasicValueEnum::IntValue(test) = test { ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); } else { unreachable!() }; ctx.builder.position_at_end(body_bb); - gen_block(generator, ctx, body.iter()); + gen_block(generator, ctx, body.iter())?; for (k, (_, _, counter)) in var_assignment.iter() { let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); if counter != counter2 { @@ -308,7 +309,7 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator>( } if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); - gen_block(generator, ctx, orelse.iter()); + gen_block(generator, ctx, orelse.iter())?; if !ctx.is_terminated() { ctx.builder.build_unconditional_branch(cont_bb); } @@ -324,13 +325,14 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator>( } else { unreachable!() } + Ok(()) } pub fn gen_if<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, -) { +) -> Result<(), String> { if let StmtKind::If { test, body, orelse, .. } = &stmt.node { // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -349,14 +351,14 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>( }; ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); - let test = generator.gen_expr(ctx, test).unwrap().to_basic_value_enum(ctx, generator); + let test = generator.gen_expr(ctx, test)?.unwrap().to_basic_value_enum(ctx, generator); if let BasicValueEnum::IntValue(test) = test { ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); } else { unreachable!() }; ctx.builder.position_at_end(body_bb); - gen_block(generator, ctx, body.iter()); + gen_block(generator, ctx, body.iter())?; for (k, (_, _, counter)) in var_assignment.iter() { let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); if counter != counter2 { @@ -372,7 +374,7 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>( } if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); - gen_block(generator, ctx, orelse.iter()); + gen_block(generator, ctx, orelse.iter())?; if !ctx.is_terminated() { if cont_bb.is_none() { cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); @@ -392,6 +394,7 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>( } else { unreachable!() } + Ok(()) } pub fn final_proxy<'ctx, 'a>( @@ -442,7 +445,7 @@ pub fn exn_constructor<'ctx, 'a>( _fun: (&FunSignature, DefinitionId), mut args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator -) -> Option> { +) -> Result>, String> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator).into_pointer_value(); let int32 = ctx.ctx.i32_type(); @@ -498,7 +501,7 @@ pub fn exn_constructor<'ctx, 'a>( ctx.builder.build_store(ptr, zero); } } - Some(zelf.into()) + Ok(Some(zelf.into())) } pub fn gen_raise<'ctx, 'a, G: CodeGenerator>( @@ -540,7 +543,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, target: &Stmt>, -) { +) -> Result<(), String> { if let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node { // if we need to generate anything related to exception, we must have personality defined let personality_symbol = ctx.top_level.personality_symbol.as_ref().unwrap(); @@ -564,7 +567,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( let mut old_return = None; let mut old_outer_final = None; let has_cleanup = if !finalbody.is_empty() { - let final_state = generator.gen_var_alloc(ctx, ptr_type.into()); + let final_state = generator.gen_var_alloc(ctx, ptr_type.into())?; old_outer_final = ctx.outer_final.replace((final_state, Vec::new(), Vec::new())); if let Some((continue_target, break_target)) = ctx.loop_target { let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); @@ -622,9 +625,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( } let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn)); let old_unwind = ctx.unwind_target.replace(landingpad); - gen_block(generator, ctx, body.iter()); + gen_block(generator, ctx, body.iter())?; if ctx.builder.get_insert_block().unwrap().get_terminator().is_none() { - gen_block(generator, ctx, orelse.iter()); + gen_block(generator, ctx, orelse.iter())?; } let body = ctx.builder.get_insert_block().unwrap(); // reset old_clauses and old_unwind @@ -723,11 +726,11 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( ctx.builder.position_at_end(handler_bb); if let Some(name) = name { let exn_ty = ctx.get_llvm_type(generator, type_.as_ref().unwrap().custom.unwrap()); - let exn_store = generator.gen_var_alloc(ctx, exn_ty); + let exn_store = generator.gen_var_alloc(ctx, exn_ty)?; ctx.var_assignment.insert(*name, (exn_store, None, 0)); ctx.builder.build_store(exn_store, exn.as_basic_value()); } - gen_block(generator, ctx, body.iter()); + gen_block(generator, ctx, body.iter())?; let current = ctx.builder.get_insert_block().unwrap(); // only need to call end catch if not terminated // otherwise, we already handled in return/break/continue/raise @@ -813,7 +816,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( // exception path let cleanup = cleanup.unwrap(); ctx.builder.position_at_end(cleanup); - gen_block(generator, ctx, finalbody.iter()); + gen_block(generator, ctx, finalbody.iter())?; if !ctx.is_terminated() { ctx.build_call_or_invoke(resume, &[], "resume"); ctx.builder.build_unreachable(); @@ -825,7 +828,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( final_targets.push(tail); let finalizer = ctx.ctx.append_basic_block(current_fun, "try.finally"); ctx.builder.position_at_end(finalizer); - gen_block(generator, ctx, finalbody.iter()); + gen_block(generator, ctx, finalbody.iter())?; if !ctx.is_terminated() { let dest = ctx.builder.build_load(final_state, "final_dest"); ctx.builder.build_indirect_branch(dest, &final_targets); @@ -847,6 +850,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( } ctx.builder.position_at_end(tail); } + Ok(()) } else { unreachable!() } @@ -855,20 +859,21 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( pub fn gen_with<'ctx, 'a, G: CodeGenerator>( _: &mut G, _: &mut CodeGenContext<'ctx, 'a>, - _: &Stmt>, -) -> bool { + stmt: &Stmt>, +) -> Result<(), String> { // TODO: Implement with statement after finishing exceptions - unimplemented!() + Err(format!("With statement with custom types is not yet supported (at {})", stmt.location)) } pub fn gen_return<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, value: &Option>>>, -) { +) -> Result<(), String> { let value = value .as_ref() - .map(|v| generator.gen_expr(ctx, v).unwrap().to_basic_value_enum(ctx, generator)); + .map(|v| generator.gen_expr(ctx, v).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) + .transpose()?; if let Some(return_target) = ctx.return_target { if let Some(value) = value { ctx.builder.build_store(ctx.return_buffer.unwrap(), value); @@ -878,31 +883,32 @@ pub fn gen_return<'ctx, 'a, G: CodeGenerator>( let value = value.as_ref().map(|v| v as &dyn BasicValue); ctx.builder.build_return(value); } + Ok(()) } pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, -) { +) -> Result<(), String> { match &stmt.node { StmtKind::Pass { .. } => {} StmtKind::Expr { value, .. } => { - generator.gen_expr(ctx, value); + generator.gen_expr(ctx, value)?; } StmtKind::Return { value, .. } => { - gen_return(generator, ctx, value); + gen_return(generator, ctx, value)?; } StmtKind::AnnAssign { target, value, .. } => { if let Some(value) = value { - let value = generator.gen_expr(ctx, value).unwrap(); - generator.gen_assign(ctx, target, value); + let value = generator.gen_expr(ctx, value)?.unwrap(); + generator.gen_assign(ctx, target, value)?; } } StmtKind::Assign { targets, value, .. } => { - let value = generator.gen_expr(ctx, value).unwrap(); + let value = generator.gen_expr(ctx, value)?.unwrap(); for target in targets.iter() { - generator.gen_assign(ctx, target, value.clone()); + generator.gen_assign(ctx, target, value.clone())?; } } StmtKind::Continue { .. } => { @@ -911,32 +917,38 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::Break { .. } => { ctx.builder.build_unconditional_branch(ctx.loop_target.unwrap().1); } - StmtKind::If { .. } => generator.gen_if(ctx, stmt), - StmtKind::While { .. } => generator.gen_while(ctx, stmt), - StmtKind::For { .. } => generator.gen_for(ctx, stmt), - StmtKind::With { .. } => generator.gen_with(ctx, stmt), + StmtKind::If { .. } => generator.gen_if(ctx, stmt)?, + StmtKind::While { .. } => generator.gen_while(ctx, stmt)?, + StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, + StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr(generator, ctx, target, op, value); - generator.gen_assign(ctx, target, value); + let value = gen_binop_expr(generator, ctx, target, op, value)?; + generator.gen_assign(ctx, target, value)?; } - StmtKind::Try { .. } => gen_try(generator, ctx, stmt), + StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { - let exc = exc.as_ref().map(|exc| generator.gen_expr(ctx, exc).unwrap().to_basic_value_enum(ctx, generator)); - gen_raise(generator, ctx, exc.as_ref(), stmt.location) + if let Some(exc) = exc { + let exc = generator.gen_expr(ctx, exc)?.unwrap().to_basic_value_enum(ctx, generator); + gen_raise(generator, ctx, Some(&exc), stmt.location); + } else { + gen_raise(generator, ctx, None, stmt.location); + } } _ => unimplemented!(), }; + Ok(()) } pub fn gen_block<'ctx, 'a, 'b, G: CodeGenerator, I: Iterator>>>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, stmts: I, -) { +) -> Result<(), String> { for stmt in stmts { - generator.gen_stmt(ctx, stmt); + generator.gen_stmt(ctx, stmt)?; if ctx.is_terminated() { break; } } + Ok(()) } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index e174f5f..ffcf330 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -18,7 +18,6 @@ use nac3parser::{ parser::parse_program, }; use parking_lot::RwLock; -use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -57,8 +56,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, id: StrRef) -> Option { - self.id_to_def.read().get(&id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Result { + self.id_to_def.read().get(&id).cloned().ok_or_else(|| format!("cannot find symbol `{}`", id)) } fn get_string_id(&self, _: &str) -> i32 { @@ -211,7 +210,7 @@ fn test_simple_call() { ret: primitives.int32, vars: HashMap::new(), }; - let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone()))); + let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone())); let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache); @@ -227,6 +226,7 @@ fn test_simple_call() { instance_to_symbol: HashMap::new(), resolver: None, codegen_callback: None, + loc: None, }))); let resolver = Resolver { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 25ad548..aed4843 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Display}; use std::fmt::Debug; -use std::{cell::RefCell, sync::Arc}; +use std::sync::Arc; use crate::{ codegen::CodeGenContext, @@ -16,7 +16,7 @@ use crate::{ use crate::typecheck::typedef::TypeEnum; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use itertools::{chain, izip}; -use nac3parser::ast::{Expr, StrRef}; +use nac3parser::ast::{Expr, Location, StrRef}; use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] @@ -29,6 +29,25 @@ pub enum SymbolValue { Tuple(Vec), } +impl Display for SymbolValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SymbolValue::I32(i) => write!(f, "{}", i), + SymbolValue::I64(i) => write!(f, "int64({})", i), + SymbolValue::Str(s) => write!(f, "\"{}\"", s), + SymbolValue::Double(d) => write!(f, "{}", d), + SymbolValue::Bool(b) => if *b { + write!(f, "True") + } else { + write!(f, "False") + }, + SymbolValue::Tuple(t) => { + write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::>().join(", ")) + } + } + } +} + pub trait StaticValue { fn get_unique_identifier(&self) -> u64; @@ -105,7 +124,7 @@ pub trait SymbolResolver { ) -> Result; // get the top-level definition of identifiers - fn get_identifier_def(&self, str: StrRef) -> Option; + fn get_identifier_def(&self, str: StrRef) -> Result; fn get_symbol_value<'ctx, 'a>( &self, @@ -154,7 +173,7 @@ pub fn parse_type_annotation( let str_id = ids[8]; let exn_id = ids[9]; - let name_handling = |id: &StrRef, unifier: &mut Unifier| { + let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { Ok(primitives.int32) } else if *id == int64_id { @@ -171,37 +190,37 @@ pub fn parse_type_annotation( Ok(primitives.exception) } else { let obj_id = resolver.get_identifier_def(*id); - if let Some(obj_id) = obj_id { - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if !type_vars.is_empty() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got 0", - type_vars.len() - )); - } - let fields = RefCell::new( - chain( + match obj_id { + Ok(obj_id) => { + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); + } + let fields = chain( fields.iter().map(|(k, v, m)| (*k, (*v, *m))), methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(), - ); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: Default::default(), - })) - } else { - Err("Cannot use function name as type".into()) + ).collect(); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) + } else { + Err(format!("Cannot use function name as type at {}", loc)) + } } - } else { - // it could be a type variable - let ty = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id)?; - if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { - Ok(ty) - } else { - Err(format!("Unknown type annotation {}", id)) + Err(e) => { + let ty = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) + .map_err(|_| format!("Unknown type annotation at {}: {}", loc, e))?; + if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { + Ok(ty) + } else { + Err(format!("Unknown type annotation {} at {}", id, loc)) + } } } } @@ -238,8 +257,7 @@ pub fn parse_type_annotation( }; let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| format!("Unknown type annotation {}", id))?; + .get_identifier_def(*id)?; let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if types.len() != type_vars.len() { @@ -271,8 +289,8 @@ pub fn parse_type_annotation( })); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, - fields: fields.into(), - params: subst.into(), + fields, + params: subst, })) } else { Err("Cannot use function name as type".into()) @@ -281,7 +299,7 @@ pub fn parse_type_annotation( }; match &expr.node { - Name { id, .. } => name_handling(id, unifier), + Name { id, .. } => name_handling(id, expr.location, unifier), Subscript { value, slice, .. } => { if let Name { id, .. } = &value.node { subscript_name_handle(id, slice, unifier) @@ -310,7 +328,7 @@ impl dyn SymbolResolver + Send + Sync { unifier: &mut Unifier, ty: Type, ) -> String { - unifier.stringify( + unifier.internal_stringify( ty, &mut |id| { if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() { @@ -320,6 +338,7 @@ impl dyn SymbolResolver + Send + Sync { } }, &mut |id| format!("var{}", id), + &mut None ) } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e8c84d7..fbf3b2a 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -4,7 +4,6 @@ use crate::{ symbol_resolver::SymbolValue, }; use inkwell::{FloatPredicate, IntPredicate}; -use std::cell::RefCell; type BuiltinInfo = ( Vec<(Arc>, Option)>, @@ -18,7 +17,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; - let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean]); + let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean], Some("N".into()), None); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let exception_fields = vec![ @@ -34,62 +33,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ]; let div_by_zero = primitives.1.add_ty(TypeEnum::TObj { obj_id: DefinitionId(10), - fields: RefCell::new(exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect()), + fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), params: Default::default() }); let index_error = primitives.1.add_ty(TypeEnum::TObj { obj_id: DefinitionId(11), - fields: RefCell::new(exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect()), + fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), params: Default::default() }); let exn_cons_args = vec![ FuncArg { name: "msg".into(), ty: string, - default_value: Some(SymbolValue::Str("".into()))}, + default_value: Some(SymbolValue::Str("".into()))}, FuncArg { name: "param0".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param1".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param2".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, ]; - let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: exn_cons_args.clone(), ret: div_by_zero, vars: Default::default() - }))); - let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + })); + let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: exn_cons_args, ret: index_error, vars: Default::default() - }))); + })); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 0, - None, - "int32".into(), - None, + 0, + None, + "int32".into(), + None, + None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 1, - None, - "int64".into(), - None, + 1, + None, + "int64".into(), + None, + None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 2, - None, - "float".into(), - None, + 2, + None, + "float".into(), + None, + None, ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(3, None, "bool".into(), None))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(4, None, "none".into(), None))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(3, None, "bool".into(), None, None))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(4, None, "none".into(), None, None))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 5, - None, - "range".into(), - None, + 5, + None, + "range".into(), + None, + None, ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(6, None, "str".into(), None))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(6, None, "str".into(), None, None))), Arc::new(RwLock::new(TopLevelDef::Class { name: "Exception".into(), object_id: DefinitionId(7), @@ -99,6 +102,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ancestors: vec![], constructor: None, resolver: None, + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "ZeroDivisionError.__init__".into(), @@ -108,7 +112,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "IndexError.__init__".into(), @@ -118,7 +123,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Class { name: "ZeroDivisionError".into(), @@ -132,6 +138,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ], constructor: Some(div_by_zero_signature), resolver: None, + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Class { name: "IndexError".into(), @@ -145,161 +152,165 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ], constructor: Some(index_error_signature), resolver: None, + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], ret: int32, vars: var_map.clone(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let boolean = ctx.primitives.bool; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - if ctx.unifier.unioned(arg_ty, boolean) { - Some( - ctx.builder - .build_int_z_extend( - arg.into_int_value(), - ctx.ctx.i32_type(), - "zext", - ) - .into(), - ) - } else if ctx.unifier.unioned(arg_ty, int32) { - Some(arg) - } else if ctx.unifier.unioned(arg_ty, int64) { - Some( - ctx.builder - .build_int_truncate( - arg.into_int_value(), - ctx.ctx.i32_type(), - "trunc", - ) - .into(), - ) - } else if ctx.unifier.unioned(arg_ty, float) { - let val = ctx - .builder - .build_float_to_signed_int( - arg.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(); - Some(val) - } else { - unreachable!() - } - }, + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok(if ctx.unifier.unioned(arg_ty, boolean) { + Some( + ctx.builder + .build_int_z_extend( + arg.into_int_value(), + ctx.ctx.i32_type(), + "zext", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, int32) { + Some(arg) + } else if ctx.unifier.unioned(arg_ty, int64) { + Some( + ctx.builder + .build_int_truncate( + arg.into_int_value(), + ctx.ctx.i32_type(), + "trunc", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, float) { + let val = ctx + .builder + .build_float_to_signed_int( + arg.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(); + Some(val) + } else { + unreachable!() + }) + }, )))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int64".into(), simple_name: "int64".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], ret: int64, vars: var_map.clone(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let boolean = ctx.primitives.bool; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - if ctx.unifier.unioned(arg_ty, boolean) - || ctx.unifier.unioned(arg_ty, int32) - { - Some( - ctx.builder - .build_int_z_extend( - arg.into_int_value(), - ctx.ctx.i64_type(), - "zext", - ) - .into(), - ) - } else if ctx.unifier.unioned(arg_ty, int64) { - Some(arg) - } else if ctx.unifier.unioned(arg_ty, float) { - let val = ctx - .builder - .build_float_to_signed_int( - arg.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(); - Some(val) - } else { - unreachable!() - } - }, + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok(if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) + { + Some( + ctx.builder + .build_int_z_extend( + arg.into_int_value(), + ctx.ctx.i64_type(), + "zext", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, int64) { + Some(arg) + } else if ctx.unifier.unioned(arg_ty, float) { + let val = ctx + .builder + .build_float_to_signed_int( + arg.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(); + Some(val) + } else { + unreachable!() + }) + }, )))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "float".into(), simple_name: "float".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], ret: float, vars: var_map.clone(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let boolean = ctx.primitives.bool; - let float = ctx.primitives.float; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - if ctx.unifier.unioned(arg_ty, boolean) - || ctx.unifier.unioned(arg_ty, int32) - || ctx.unifier.unioned(arg_ty, int64) - { - let arg = arg.into_int_value(); - let val = ctx - .builder - .build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp") - .into(); - Some(val) - } else if ctx.unifier.unioned(arg_ty, float) { - Some(arg) - } else { - unreachable!() - } - }, + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let boolean = ctx.primitives.bool; + let float = ctx.primitives.float; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok(if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) + || ctx.unifier.unioned(arg_ty, int64) + { + let arg = arg.into_int_value(); + let val = ctx + .builder + .build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp") + .into(); + Some(val) + } else if ctx.unifier.unioned(arg_ty, float) { + Some(arg) + } else { + unreachable!() + }) + }, )))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "round".into(), simple_name: "round".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int32, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -318,25 +329,26 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(), - ) + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(), + )) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "round64".into(), simple_name: "round64".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int64, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -355,21 +367,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(), - ) + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(), + )) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "range".into(), simple_name: "range".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "start".into(), ty: int32, default_value: None }, FuncArg { @@ -386,7 +399,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ], ret: range, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -438,33 +451,35 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ctx.builder.build_store(b, stop); ctx.builder.build_store(c, step); } - Some(ptr.into()) + Ok(Some(ptr.into())) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "str".into(), simple_name: "str".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: string, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "s".into(), ty: string, default_value: None }], ret: string, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - Some(args[0].1.clone().to_basic_value_enum(ctx, generator)) + Ok(Some(args[0].1.clone().to_basic_value_enum(ctx, generator))) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "bool".into(), simple_name: "bool".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], ret: primitives.0.bool, vars: var_map, - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -477,7 +492,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = ctx.primitives.bool; let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - if ctx.unifier.unioned(arg_ty, boolean) { + Ok(if ctx.unifier.unioned(arg_ty, boolean) { Some(arg) } else if ctx.unifier.unioned(arg_ty, int32) { Some(ctx.builder.build_int_compare( @@ -505,18 +520,19 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some(val) } else { unreachable!() - } + }) }, )))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "floor".into(), simple_name: "floor".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int32, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -535,7 +551,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder .build_float_to_signed_int( val.into_float_value(), @@ -543,17 +559,18 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "fptosi", ) .into(), - ) + )) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "floor64".into(), simple_name: "floor64".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int64, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -572,7 +589,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder .build_float_to_signed_int( val.into_float_value(), @@ -580,17 +597,18 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "fptosi", ) .into(), - ) + )) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "ceil".into(), simple_name: "ceil".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int32, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -609,7 +627,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder .build_float_to_signed_int( val.into_float_value(), @@ -617,17 +635,18 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "fptosi", ) .into(), - ) + )) })))), + loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { name: "ceil64".into(), simple_name: "ceil64".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { name: "_".into(), ty: float, default_value: None }], + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], ret: int64, vars: Default::default(), - }))), + })), var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -646,7 +665,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .try_as_basic_value() .left() .unwrap(); - Some( + Ok(Some( ctx.builder .build_float_to_signed_int( val.into_float_value(), @@ -654,25 +673,26 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "fptosi", ) .into(), - ) + )) })))), + loc: None, })), Arc::new(RwLock::new({ - let list_var = primitives.1.get_fresh_var(); + let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); - let arg_ty = primitives.1.get_fresh_var_with_range(&[list, primitives.0.range]); + let arg_ty = primitives.1.get_fresh_var_with_range(&[list, primitives.0.range], Some("I".into()), None); TopLevelDef::Function { name: "len".into(), simple_name: "len".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { - name: "_".into(), + name: "ls".into(), ty: arg_ty.0, default_value: None }], ret: int32, vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), - }))), + })), var_id: vec![arg_ty.1], instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), @@ -682,7 +702,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let range_ty = ctx.primitives.range; let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - if ctx.unifier.unioned(arg_ty, range_ty) { + Ok(if ctx.unifier.unioned(arg_ty, range_ty) { let arg = arg.into_pointer_value(); let (start, end, step) = destructure_range(ctx, arg); Some(calculate_len_for_slice_range(ctx, start, end, step).into()) @@ -695,9 +715,10 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { } else { Some(len.into()) } - } + }) }, )))), + loc: None, } })) ]; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2357e9a..e5a2920 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - use nac3parser::ast::fold::Fold; use crate::{ @@ -102,7 +100,7 @@ impl TopLevelComposer { } for (name, sig, codegen_callback) in builtins { - let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig))); + let fun_sig = unifier.add_ty(TypeEnum::TFunc(sig)); builtin_ty.insert(name, fun_sig); builtin_id.insert(name, DefinitionId(definition_ast_list.len())); definition_ast_list.push(( @@ -115,6 +113,7 @@ impl TopLevelComposer { var_id: Default::default(), resolver: None, codegen_callback: Some(codegen_callback), + loc: None, })), None, )); @@ -192,13 +191,14 @@ impl TopLevelComposer { // since later when registering class method, ast will still be used, // here push None temporarily, later will move the ast inside - let constructor_ty = self.unifier.get_fresh_var().0; + let constructor_ty = self.unifier.get_dummy_var().0; let mut class_def_ast = ( Arc::new(RwLock::new(Self::make_top_level_class_def( class_def_id, resolver.clone(), fully_qualified_class_name, Some(constructor_ty), + Some(ast.location) ))), None, ); @@ -256,19 +256,20 @@ impl TopLevelComposer { }; // dummy method define here - let dummy_method_type = self.unifier.get_fresh_var(); + let dummy_method_type = self.unifier.get_dummy_var().0; class_method_name_def_ids.push(( *method_name, RwLock::new(Self::make_top_level_function_def( global_class_method_name, *method_name, // later unify with parsed type - dummy_method_type.0, + dummy_method_type, resolver.clone(), + Some(b.location), )) .into(), DefinitionId(method_def_id), - dummy_method_type.0, + dummy_method_type, b.clone(), )); } else { @@ -300,9 +301,6 @@ impl TopLevelComposer { } ast::StmtKind::FunctionDef { name, .. } => { - // if self.keyword_list.contains(name) { - // return Err("cannot use keyword as a top level function name".into()); - // } let global_fun_name = if mod_path.is_empty() { name.to_string() } else { @@ -317,7 +315,7 @@ impl TopLevelComposer { } let fun_name = *name; - let ty_to_be_unified = self.unifier.get_fresh_var().0; + let ty_to_be_unified = self.unifier.get_dummy_var().0; // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( @@ -326,6 +324,7 @@ impl TopLevelComposer { // dummy here, unify with correct type later ty_to_be_unified, resolver, + Some(ast.location) )) .into(), Some(ast), @@ -364,8 +363,7 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); let primitives_store = &self.primitives_ty; - // skip 5 to skip analyzing the primitives - for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { + let mut analyze = |class_def: &Arc>, class_ast: &Option| { // only deal with class def here let mut class_def = class_def.write(); let (class_bases_ast, class_def_type_vars, class_resolver) = { @@ -379,7 +377,7 @@ impl TopLevelComposer { unreachable!("must be both class") } } else { - continue; + return Ok(()) } }; let class_resolver = class_resolver.as_ref().unwrap(); @@ -459,6 +457,16 @@ impl TopLevelComposer { _ => continue, } } + Ok(()) + }; + let mut errors = HashSet::new(); + for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { + if let Err(e) = analyze(class_def, class_ast) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); } Ok(()) } @@ -474,9 +482,9 @@ impl TopLevelComposer { let temp_def_list = self.extract_def_list(); let unifier = self.unifier.borrow_mut(); + let primitive_types = self.primitives_ty; - // first, only push direct parent into the list - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { + let mut get_direct_parents = |class_def: &Arc>, class_ast: &Option| { let mut class_def = class_def.write(); let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = @@ -491,7 +499,7 @@ impl TopLevelComposer { unreachable!("must be both class") } } else { - continue; + return Ok(()); } }; let class_resolver = class_resolver.as_ref().unwrap(); @@ -526,7 +534,7 @@ impl TopLevelComposer { class_resolver, &temp_def_list, unifier, - &self.primitives_ty, + &primitive_types, b, vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), )?; @@ -540,17 +548,29 @@ impl TopLevelComposer { )); } } + Ok(()) + }; + + // first, only push direct parent into the list + let mut errors = HashSet::new(); + for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { + if let Err(e) = get_direct_parents(class_def, class_ast) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); } // second, get all ancestors let mut ancestors_store: HashMap> = Default::default(); - for (class_def, _) in self.definition_ast_list.iter().skip(self.builtin_num) { + let mut get_all_ancestors = |class_def: &Arc>| { let class_def = class_def.read(); let (class_ancestors, class_id) = { if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() { (ancestors, *object_id) } else { - continue; + return Ok(()) } }; ancestors_store.insert( @@ -562,6 +582,15 @@ impl TopLevelComposer { Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())? }, ); + Ok(()) + }; + for (class_def, _) in self.definition_ast_list.iter().skip(self.builtin_num) { + if let Err(e) = get_all_ancestors(class_def) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); } // insert the ancestors to the def list @@ -619,9 +648,10 @@ impl TopLevelComposer { let mut type_var_to_concrete_def: HashMap = HashMap::new(); + let mut errors = HashSet::new(); for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) { if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { - Self::analyze_single_class_methods_fields( + if let Err(e) = Self::analyze_single_class_methods_fields( class_def.clone(), &class_ast.as_ref().unwrap().node, &temp_def_list, @@ -629,11 +659,19 @@ impl TopLevelComposer { primitives, &mut type_var_to_concrete_def, (&self.keyword_list, &self.core_config) - )? + ) { + errors.insert(e); + } } } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); + } // handle the inheritanced methods and fields + // Note: we cannot defer error handling til the end of the loop, because there is loop + // carried dependency, ignoring the error (temporarily) will cause all assumptions to break + // and produce weird error messages let mut current_ancestor_depth: usize = 2; loop { let mut finished = true; @@ -668,10 +706,19 @@ impl TopLevelComposer { } // unification of previously assigned typevar - for (ty, def) in type_var_to_concrete_def { + let mut unification_helper = |ty, def| { let target_ty = get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; - unifier.unify(ty, target_ty)?; + unifier.unify(ty, target_ty).map_err(|e| e.to_display(unifier).to_string())?; + Ok(()) as Result<(), String> + }; + for (ty, def) in type_var_to_concrete_def { + if let Err(e) = unification_helper(ty, def) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); } Ok(()) @@ -685,15 +732,15 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); let primitives_store = &self.primitives_ty; - // skip 5 to skip analyzing the primitives - for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { + let mut errors = HashSet::new(); + let mut analyze = |function_def: &Arc>, function_ast: &Option| { let mut function_def = function_def.write(); let function_def = function_def.deref_mut(); let function_ast = if let Some(x) = function_ast.as_ref() { x } else { // if let TopLevelDef::Function { name, .. } = `` - continue; + return Ok(()) }; if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = @@ -701,7 +748,7 @@ impl TopLevelComposer { { if matches!(unifier.get_ty(*dummy_ty).as_ref(), TypeEnum::TFunc(_)) { // already have a function type, is class method, skip - continue; + return Ok(()); } if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { let resolver = resolver.as_ref(); @@ -854,7 +901,7 @@ impl TopLevelComposer { var_id.extend_from_slice(function_var_map .iter() .filter_map(|(id, ty)| { - if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { None } else { Some(*id) @@ -865,18 +912,26 @@ impl TopLevelComposer { ); let function_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } - .into(), )); unifier .unify(*dummy_ty, function_ty) - .map_err(|old| format!("{} (at {})", old, function_ast.location))?; + .map_err(|e| e.at(Some(function_ast.location)).to_display(unifier).to_string())?; } else { unreachable!("must be both function"); } } else { // not top level function def, skip - continue; + return Ok(()) } + Ok(()) + }; + for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { + if let Err(e) = analyze(function_def, function_ast) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")) } Ok(()) } @@ -1022,7 +1077,7 @@ impl TopLevelComposer { // finish handling type vars let dummy_func_arg = FuncArg { name, - ty: unifier.get_fresh_var().0, + ty: unifier.get_dummy_var().0, default_value: match default { None => None, Some(default) => { @@ -1074,13 +1129,13 @@ impl TopLevelComposer { unreachable!("must be type var annotation"); } } - let dummy_return_type = unifier.get_fresh_var().0; + let dummy_return_type = unifier.get_dummy_var().0; type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); dummy_return_type } else { // if do not have return annotation, return none // for uniform handling, still use type annoatation - let dummy_return_type = unifier.get_fresh_var().0; + let dummy_return_type = unifier.get_dummy_var().0; type_var_to_concrete_def.insert( dummy_return_type, TypeAnnotation::Primitive(primitives.none), @@ -1095,7 +1150,7 @@ impl TopLevelComposer { var_id.extend_from_slice(method_var_map .iter() .filter_map(|(id, ty)| { - if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { None } else { Some(*id) @@ -1114,12 +1169,12 @@ impl TopLevelComposer { // unify now since function type is not in type annotation define // which should be fine since type within method_type will be subst later - unifier.unify(method_dummy_ty, method_type)?; + unifier.unify(method_dummy_ty, method_type).map_err(|e| e.to_display(unifier).to_string())?; } ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { - let dummy_field_type = unifier.get_fresh_var().0; + let dummy_field_type = unifier.get_dummy_var().0; // handle Kernel[T], KernelInvariant[T] let (annotation, mutable) = match &annotation.node { @@ -1314,7 +1369,12 @@ impl TopLevelComposer { let init_str_id = "__init__".into(); let mut definition_extension = Vec::new(); let mut constructors = Vec::new(); - for (i, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { + let def_list = self.extract_def_list(); + let primitives_ty = &self.primitives_ty; + let definition_ast_list = &self.definition_ast_list; + let unifier = &mut self.unifier; + let mut errors = HashSet::new(); + let mut analyze = |i, def: &Arc>, ast: &Option| { let class_def = def.read(); if let TopLevelDef::Class { constructor, @@ -1329,29 +1389,29 @@ impl TopLevelComposer { } = &*class_def { let self_type = get_type_from_type_annotation_kinds( - self.extract_def_list().as_slice(), - &mut self.unifier, - &self.primitives_ty, + &def_list, + unifier, + primitives_ty, &make_self_type_annotation(type_vars, *object_id), )?; if ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { // create constructor for these classes - let string = self.primitives_ty.str; - let int64 = self.primitives_ty.int64; - let signature = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + let string = primitives_ty.str; + let int64 = primitives_ty.int64; + let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "msg".into(), ty: string, - default_value: Some(SymbolValue::Str("".into()))}, + default_value: Some(SymbolValue::Str("".into()))}, FuncArg { name: "param0".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param1".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, FuncArg { name: "param2".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + default_value: Some(SymbolValue::I64(0))}, ], ret: self_type, vars: Default::default() - }))); + })); let cons_fun = TopLevelDef::Function { name: format!("{}.{}", class_name, "__init__"), simple_name: init_str_id, @@ -1360,14 +1420,16 @@ impl TopLevelComposer { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), + loc: None }; constructors.push((i, signature, definition_extension.len())); definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); - self.unifier + unifier .unify(constructor.unwrap(), signature) - .map_err(|old| format!("{} (at {})", old, ast.as_ref().unwrap().location))?; - continue; + .map_err(|e| e.at(Some(ast.as_ref().unwrap().location)) + .to_display(unifier).to_string())?; + return Ok(()); } let mut init_id: Option = None; // get the class contructor type correct @@ -1377,8 +1439,7 @@ impl TopLevelComposer { for (name, func_sig, id) in methods { if *name == init_str_id { init_id = Some(*id); - if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { - let FunSignature { args, vars, .. } = &*sig.borrow(); + if let TypeEnum::TFunc(FunSignature { args, vars, ..}) = unifier.get_ty(*func_sig).as_ref() { constructor_args.extend_from_slice(args); type_vars.extend(vars); } else { @@ -1388,18 +1449,17 @@ impl TopLevelComposer { } (constructor_args, type_vars) }; - let contor_type = self.unifier.add_ty(TypeEnum::TFunc( + let contor_type = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: contor_args, ret: self_type, vars: contor_type_vars } - .into(), )); - self.unifier + unifier .unify(constructor.unwrap(), contor_type) - .map_err(|old| format!("{} (at {})", old, ast.as_ref().unwrap().location))?; + .map_err(|e| e.at(Some(ast.as_ref().unwrap().location)).to_display(&unifier).to_string())?; // class field instantiation check if let (Some(init_id), false) = (init_id, fields.is_empty()) { let init_ast = - self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); + definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { if *name != init_str_id { unreachable!("must be init function here") @@ -1418,7 +1478,17 @@ impl TopLevelComposer { } } } + Ok(()) + }; + for (i, (def, ast)) in definition_ast_list.iter().enumerate().skip(self.builtin_num) { + if let Err(e) = analyze(i, def, ast) { + errors.insert(e); + } } + if !errors.is_empty() { + return Err(errors.iter().join("\n---------\n")); + } + for (i, signature, id) in constructors.into_iter() { if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() { methods.push((init_str_id, signature, @@ -1431,10 +1501,14 @@ impl TopLevelComposer { let ctx = Arc::new(self.make_top_level_context()); // type inference inside function body - for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) - { + let def_list = self.extract_def_list(); + let primitives_ty = &self.primitives_ty; + let definition_ast_list = &self.definition_ast_list; + let unifier = &mut self.unifier; + let method_class = &mut self.method_class; + let mut analyze_2 = |id, def: &Arc>, ast: &Option| { if ast.is_none() { - continue; + return Ok(()) } let mut function_def = def.write(); if let TopLevelDef::Function { @@ -1448,19 +1522,18 @@ impl TopLevelComposer { .. } = &mut *function_def { - if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { - let FunSignature { args, ret, vars } = &*func_sig.borrow(); + if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = unifier.get_ty(*signature).as_ref() { // None if is not class method let uninst_self_type = { - if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { - let class_def = self.definition_ast_list.get(class_id.0).unwrap(); + if let Some(class_id) = method_class.get(&DefinitionId(id)) { + let class_def = definition_ast_list.get(class_id.0).unwrap(); let class_def = class_def.0.read(); if let TopLevelDef::Class { type_vars, .. } = &*class_def { let ty_ann = make_self_type_annotation(type_vars, *class_id); let self_ty = get_type_from_type_annotation_kinds( - self.extract_def_list().as_slice(), - &mut self.unifier, - &self.primitives_ty, + &def_list, + unifier, + primitives_ty, &ty_ann, )?; Some((self_ty, type_vars.clone())) @@ -1474,16 +1547,19 @@ impl TopLevelComposer { // carefully handle those with bounds, without bounds and no typevars // if class methods, `vars` also contains all class typevars here let (type_var_subst_comb, no_range_vars) = { - let unifier = &mut self.unifier; let mut no_ranges: Vec = Vec::new(); let var_ids = vars.keys().copied().collect_vec(); let var_combs = vars .iter() .map(|(_, ty)| { unifier.get_instantiations(*ty).unwrap_or_else(|| { - let rigid = unifier.get_fresh_rigid_var().0; - no_ranges.push(rigid); - vec![rigid] + if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) { + let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; + no_ranges.push(rigid); + vec![rigid] + } else { + unreachable!() + } }) }) .multi_cartesian_product() @@ -1501,9 +1577,8 @@ impl TopLevelComposer { for subst in type_var_subst_comb { // for each instance - let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret); + let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); let inst_args = { - let unifier = &mut self.unifier; args.iter() .map(|a| FuncArg { name: a.name, @@ -1513,7 +1588,6 @@ impl TopLevelComposer { .collect_vec() }; let self_type = { - let unifier = &mut self.unifier; uninst_self_type .clone() .map(|(self_type, type_vars)| { @@ -1558,9 +1632,8 @@ impl TopLevelComposer { defined_identifiers: identifiers.clone(), function_data: &mut FunctionData { resolver: resolver.as_ref().unwrap().clone(), - return_type: if self - .unifier - .unioned(inst_ret, self.primitives_ty.none) + return_type: if unifier + .unioned(inst_ret, primitives_ty.none) { None } else { @@ -1569,18 +1642,18 @@ impl TopLevelComposer { // NOTE: allowed type vars bound_variables: no_range_vars.clone(), }, - unifier: &mut self.unifier, + unifier, variable_mapping: { // NOTE: none and function args? let mut result: HashMap = HashMap::new(); - result.insert("None".into(), self.primitives_ty.none); + result.insert("None".into(), primitives_ty.none); if let Some(self_ty) = self_type { result.insert("self".into(), self_ty); } result.extend(inst_args.iter().map(|x| (x.name, x.ty))); result }, - primitives: &self.primitives_ty, + primitives: primitives_ty, virtual_checks: &mut Vec::new(), calls: &mut calls, in_handler: false @@ -1631,8 +1704,8 @@ impl TopLevelComposer { if let TypeEnum::TObj { obj_id, .. } = &*ty { *obj_id } else { - let base_repr = inferencer.unifier.default_stringify(*base); - let subtype_repr = inferencer.unifier.default_stringify(*subtype); + let base_repr = inferencer.unifier.stringify(*base); + let subtype_repr = inferencer.unifier.stringify(*subtype); return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) } }; @@ -1641,8 +1714,8 @@ impl TopLevelComposer { let m = ancestors.iter() .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); if m.is_none() { - let base_repr = inferencer.unifier.default_stringify(*base); - let subtype_repr = inferencer.unifier.default_stringify(*subtype); + let base_repr = inferencer.unifier.stringify(*base); + let subtype_repr = inferencer.unifier.stringify(*subtype); return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) } } else { @@ -1650,9 +1723,9 @@ impl TopLevelComposer { } } } - if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned { - let def_ast_list = &self.definition_ast_list; - let ret_str = self.unifier.stringify( + if !unifier.unioned(inst_ret, primitives_ty.none) && !returned { + let def_ast_list = &definition_ast_list; + let ret_str = unifier.internal_stringify( inst_ret, &mut |id| { if let TopLevelDef::Class { name, .. } = @@ -1664,6 +1737,7 @@ impl TopLevelComposer { } }, &mut |id| format!("tvar{}", id), + &mut None, ); return Err(format!( "expected return type of `{}` in function `{}` (at {})", @@ -1675,7 +1749,7 @@ impl TopLevelComposer { instance_to_stmt.insert( get_subst_key( - &mut self.unifier, + unifier, self_type, &subst, Some(insted_vars), @@ -1691,9 +1765,16 @@ impl TopLevelComposer { } else { unreachable!("must be typeenum::tfunc") } - } else { - continue; } + Ok(()) + }; + for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { + if let Err(e) = analyze_2(id, def, ast) { + errors.insert(e); + } + } + if !errors.is_empty() { + return Err(errors.iter().join("\n----------\n")); } Ok(()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 6d33a33..21341fc 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -18,14 +18,14 @@ impl TopLevelDef { let fields_str = fields .iter() .map(|(n, ty, _)| { - (n.to_string(), unifier.default_stringify(*ty)) + (n.to_string(), unifier.stringify(*ty)) }) .collect_vec(); let methods_str = methods .iter() .map(|(n, ty, id)| { - (n.to_string(), unifier.default_stringify(*ty), *id) + (n.to_string(), unifier.stringify(*ty), *id) }) .collect_vec(); format!( @@ -34,13 +34,13 @@ impl TopLevelDef { ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), - type_vars.iter().map(|id| unifier.default_stringify(*id)).collect_vec(), + type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) } TopLevelDef::Function { name, signature, var_id, .. } => format!( "Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}", name, - unifier.default_stringify(*signature), + unifier.stringify(*signature), { // preserve the order for debug output and test let mut r = var_id.clone(); @@ -117,6 +117,7 @@ impl TopLevelComposer { resolver: Option>, name: StrRef, constructor: Option, + loc: Option ) -> TopLevelDef { TopLevelDef::Class { name, @@ -127,6 +128,7 @@ impl TopLevelComposer { ancestors: Default::default(), constructor, resolver, + loc, } } @@ -136,6 +138,7 @@ impl TopLevelComposer { simple_name: StrRef, ty: Type, resolver: Option>, + loc: Option ) -> TopLevelDef { TopLevelDef::Function { name, @@ -146,6 +149,7 @@ impl TopLevelComposer { instance_to_stmt: Default::default(), resolver, codegen_callback: None, + loc, } } @@ -244,12 +248,8 @@ impl TopLevelComposer { let this = this.as_ref(); let other = unifier.get_ty(other); let other = other.as_ref(); - if let (TypeEnum::TFunc(this_sig), TypeEnum::TFunc(other_sig)) = (this, other) { - let (this_sig, other_sig) = (&*this_sig.borrow(), &*other_sig.borrow()); - let ( - FunSignature { args: this_args, ret: this_ret, vars: _this_vars }, - FunSignature { args: other_args, ret: other_ret, vars: _other_vars }, - ) = (this_sig, other_sig); + if let (TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, ..}), + TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. })) = (this, other) { // check args let args_ok = this_args .iter() diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index b14ba4e..643594d 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -17,7 +17,7 @@ use crate::{ }; use inkwell::values::BasicValueEnum; use itertools::{izip, Itertools}; -use nac3parser::ast::{self, Stmt, StrRef}; +use nac3parser::ast::{self, Location, Stmt, StrRef}; use parking_lot::RwLock; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] @@ -39,7 +39,7 @@ type GenCallCallback = Box< (&FunSignature, DefinitionId), Vec<(Option, ValueEnum<'ctx>)>, &mut dyn CodeGenerator, - ) -> Option> + ) -> Result>, String> + Send + Sync, >; @@ -60,7 +60,7 @@ impl GenCall { fun: (&FunSignature, DefinitionId), args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, - ) -> Option> { + ) -> Result>, String> { (self.fp)(ctx, obj, fun, args, generator) } } @@ -99,6 +99,8 @@ pub enum TopLevelDef { resolver: Option>, // constructor type constructor: Option, + // definition location + loc: Option, }, Function { // prefix for symbol, should be unique globally @@ -124,6 +126,8 @@ pub enum TopLevelDef { resolver: Option>, // custom codegen callback codegen_callback: Option>, + // definition location + loc: Option, }, } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 73e3178..793af15 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -1,14 +1,14 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 541 +assertion_line: 540 expression: res_vec --- [ - "Class {\nname: \"Generic_A\",\nancestors: [\"{class: Generic_A, params: [\\\"var6\\\"]}\", \"{class: B, params: []}\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b=var5], none]\"), (\"fun\", \"fn[[a=int32], var6]\")],\ntype_vars: [\"var6\"]\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"{class: Generic_A, params: [\\\"V\\\"]}\", \"{class: B, params: []}\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: [6]\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a=int32], var6]\",\nvar_id: [6]\n}\n", - "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b=var5], none]\")],\ntype_vars: []\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [6, 17]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.foo\",\nsig: \"fn[[b=var5], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 97380a4..4227911 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -1,17 +1,17 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 541 +assertion_line: 540 expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var5\\\"]}\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t=var5], none]\"), (\"fun\", \"fn[[a=int32, b=var5], list[virtual[B[6->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: [\"var5\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[t=var5], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=int32, b=var5], list[virtual[B[6->bool]]]]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[c=C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"{class: B, params: [\\\"var6\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=int32, b=var5], list[virtual[B[6->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: [\"var6\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"T\\\"]}\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"{class: B, params: [\\\"var6\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"var6\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: [6]\n}\n", - "Function {\nname: \"B.fun\",\nsig: \"fn[[a=int32, b=var5], list[virtual[B[6->bool]]]]\",\nvar_id: [6]\n}\n", - "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: B, params: [\\\"bool\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=int32, b=var5], list[virtual[B[6->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: []\n}\n", + "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: [6]\n}\n", + "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: B, params: [\\\"bool\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 2b1f1ee..6dc65e9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -1,15 +1,15 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 541 +assertion_line: 540 expression: res_vec --- [ - "Function {\nname: \"foo\",\nsig: \"fn[[a=list[int32], b=tuple[var5, float]], A[5->B, 6->bool]]\",\nvar_id: []\n}\n", - "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var5\\\", \\\"var6\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v=var6], none]\"), (\"fun\", \"fn[[a=var5], var6]\")],\ntype_vars: [\"var5\", \"var6\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v=var6], none]\",\nvar_id: [6]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=var5], var6]\",\nvar_id: [6]\n}\n", - "Function {\nname: \"gfun\",\nsig: \"fn[[a=A[5->list[float], 6->int32]], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"T\\\", \\\"V\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [18, 19]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [19, 24]\n}\n", + "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 472c450..7ccdd44 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -1,15 +1,15 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 541 +assertion_line: 540 expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var5\\\", \\\"var6\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a=A[5->float, 6->bool], b=B], none]\"), (\"fun\", \"fn[[a=A[5->float, 6->bool]], A[5->bool, 6->int32]]\")],\ntype_vars: [\"var5\", \"var6\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[a=A[5->float, 6->bool], b=B], none]\",\nvar_id: [6]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=A[5->float, 6->bool]], A[5->bool, 6->int32]]\",\nvar_id: [6]\n}\n", - "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: A, params: [\\\"int64\\\", \\\"bool\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=A[5->float, 6->bool]], A[5->bool, 6->int32]]\"), (\"foo\", \"fn[[b=B], B]\"), (\"bar\", \"fn[[a=A[5->list[B], 6->int32]], tuple[A[5->virtual[A[5->B, 6->int32]], 6->bool], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var5\\\", \\\"var6\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"var5\", \"var6\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: [6]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: [6]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: A, params: [\\\"int64\\\", \\\"bool\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.foo\",\nsig: \"fn[[b=B], B]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.bar\",\nsig: \"fn[[a=A[5->list[B], 6->int32]], tuple[A[5->virtual[A[5->B, 6->int32]], 6->bool], B]]\",\nvar_id: []\n}\n", + "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", + "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index b2bd464..cf18240 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -1,19 +1,19 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 541 +assertion_line: 540 expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var5, b=var6], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[b=B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a=var5, b=var6], none]\",\nvar_id: [6]\n}\n", - "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var5, b=var6], none]\")],\ntype_vars: []\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [25]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var5, b=var6], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"C.fun\",\nsig: \"fn[[b=B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"foo\",\nsig: \"fn[[a=A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a=var5], var6]\",\nvar_id: [6]\n}\n", + "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [33]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 99a6ed2..e35846f 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -61,8 +61,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, id: StrRef) -> Option { - self.0.id_to_def.lock().get(&id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Result { + self.0.id_to_def.lock().get(&id).cloned().ok_or("Unknown identifier".to_string()) } fn get_string_id(&self, _: &str) -> i32 { @@ -129,9 +129,9 @@ fn test_simple_register(source: Vec<&str>) { "}, ], vec![ - "fn[[a=0], 0]", - "fn[[a=2], 4]", - "fn[[b=1], 0]", + "fn[[a:0], 0]", + "fn[[a:2], 4]", + "fn[[b:1], 0]", ], vec![ "fun", @@ -172,7 +172,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s let ty_str = composer .unifier - .stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string()); + .internal_stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None); assert_eq!(ty_str, tys[i]); assert_eq!(name, names[i]); } @@ -752,7 +752,7 @@ fn make_internal_resolver_with_tvar( .into_iter() .map(|(name, range)| { (name, { - let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice()); + let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None); if print { println!("{}: {:?}, tvar{}", name, ty, id); } @@ -779,9 +779,9 @@ impl<'a> Fold> for TypeToStringFolder<'a> { type Error = String; fn map_user(&mut self, user: Option) -> Result { Ok(if let Some(ty) = user { - self.unifier.stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| { + self.unifier.internal_stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| { format!("tvar{}", id.to_string()) - }) + }, &mut None) } else { "None".into() }) diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 72bc59c..0064ac4 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,6 +1,3 @@ -use std::cell::RefCell; - -use crate::typecheck::typedef::TypeVarMeta; use super::*; #[derive(Clone, Debug)] @@ -23,7 +20,7 @@ impl TypeAnnotation { pub fn stringify(&self, unifier: &mut Unifier) -> String { use TypeAnnotation::*; match self { - Primitive(ty) | TypeVar(ty) => unifier.default_stringify(*ty), + Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), CustomClass { id, params } => { let class_name = match unifier.top_level { Some(ref top) => if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() { @@ -65,7 +62,7 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Primitive(primitives.str)) } else if id == &"Exception".into() { Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }) - } else if let Some(obj_id) = resolver.get_identifier_def(*id) { + } else if let Ok(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); if let Some(def_read) = def_read { @@ -92,6 +89,8 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { + let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; + unifier.unify(var, ty).unwrap(); Ok(TypeAnnotation::TypeVar(ty)) } else { Err(format!( @@ -113,8 +112,7 @@ pub fn parse_ast_to_type_annotation_kinds( return Err(format!("keywords cannot be class name (at {})", expr.location)); } let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| "unknown class name".to_string())?; + .get_identifier_def(*id)?; let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); if let Some(def_read) = def_read { @@ -293,14 +291,14 @@ pub fn get_type_from_type_annotation_kinds( // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check let mut result: HashMap = HashMap::new(); for (tvar, p) in type_vars.iter().zip(param_ty) { - if let TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic } = + if let TypeEnum::TVar { id, range, fields: None, name, loc } = unifier.get_ty(*tvar).as_ref() { let ok: bool = { // create a temp type var and unify to check compatibility p == *tvar || { let temp = - unifier.get_fresh_var_with_range(range.borrow().as_slice()); + unifier.get_fresh_var_with_range(range.as_slice(), *name, *loc); unifier.unify(temp.0, p).is_ok() } }; @@ -309,10 +307,11 @@ pub fn get_type_from_type_annotation_kinds( } else { return Err(format!( "cannot apply type {} to type variable with id {:?}", - unifier.stringify( + unifier.internal_stringify( p, &mut |id| format!("class{}", id), - &mut |id| format!("tvar{}", id) + &mut |id| format!("tvar{}", id), + &mut None ), *id )); @@ -338,7 +337,7 @@ pub fn get_type_from_type_annotation_kinds( Ok(unifier.add_ty(TypeEnum::TObj { obj_id: *obj_id, - fields: RefCell::new(tobj_fields), + fields: tobj_fields, params: subst.into(), })) } @@ -438,8 +437,8 @@ pub fn check_overload_type_annotation_compatible( let b = unifier.get_ty(*b); let b = b.deref(); if let ( - TypeEnum::TVar { id: a, meta: TypeVarMeta::Generic, .. }, - TypeEnum::TVar { id: b, meta: TypeVarMeta::Generic, .. }, + TypeEnum::TVar { id: a, fields: None, .. }, + TypeEnum::TVar { id: b, fields: None, .. }, ) = (a, b) { a == b diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 8de79d2..cced74e 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -2,10 +2,10 @@ use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; -use nac3parser::ast; +use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{Cmpop, Operator, Unaryop}; -use std::borrow::Borrow; use std::collections::HashMap; +use std::rc::Rc; pub fn binop_name(op: &Operator) -> &'static str { match op { @@ -64,6 +64,25 @@ pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { } } +pub(super) fn with_fields(unifier: &mut Unifier, ty: Type, f: F) + where F: FnOnce(&mut Unifier, &mut HashMap) +{ + let (id, mut fields, params) = if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) { + (*obj_id, fields.clone(), params.clone()) + } else { + unreachable!() + }; + f(unifier, &mut fields); + unsafe { + let unification_table = unifier.get_unification_table(); + unification_table.set_value(ty, Rc::new(TypeEnum::TObj { + obj_id: id, + fields, + params, + })); + } +} + pub fn impl_binop( unifier: &mut Unifier, store: &PrimitiveStore, @@ -72,11 +91,11 @@ pub fn impl_binop( ret_ty: Type, ops: &[ast::Operator], ) { - if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() { + with_fields(unifier, ty, |unifier, fields| { let (other_ty, other_var_id) = if other_ty.len() == 1 { (other_ty[0], None) } else { - let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty); + let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); (ty, Some(var_id)) }; let function_vars = if let Some(var_id) = other_var_id { @@ -85,7 +104,7 @@ pub fn impl_binop( HashMap::new() }; for op in ops { - fields.borrow_mut().insert(binop_name(op).into(), { + fields.insert(binop_name(op).into(), { ( unifier.add_ty(TypeEnum::TFunc( FunSignature { @@ -97,13 +116,12 @@ pub fn impl_binop( name: "other".into(), }], } - .into(), )), false, ) }); - fields.borrow_mut().insert(binop_assign_name(op).into(), { + fields.insert(binop_assign_name(op).into(), { ( unifier.add_ty(TypeEnum::TFunc( FunSignature { @@ -115,39 +133,33 @@ pub fn impl_binop( name: "other".into(), }], } - .into(), )), false, ) }); } - } else { - unreachable!("") - } + }); } pub fn impl_unaryop( unifier: &mut Unifier, - _store: &PrimitiveStore, ty: Type, ret_ty: Type, ops: &[ast::Unaryop], ) { - if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() { + with_fields(unifier, ty, |unifier, fields| { for op in ops { - fields.borrow_mut().insert( + fields.insert( unaryop_name(op).into(), ( unifier.add_ty(TypeEnum::TFunc( - FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(), + FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] } )), false, ), ); } - } else { - unreachable!() - } + }); } pub fn impl_cmpop( @@ -157,9 +169,9 @@ pub fn impl_cmpop( other_ty: Type, ops: &[ast::Cmpop], ) { - if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() { + with_fields(unifier, ty, |unifier, fields| { for op in ops { - fields.borrow_mut().insert( + fields.insert( comparison_name(op).unwrap().into(), ( unifier.add_ty(TypeEnum::TFunc( @@ -172,15 +184,12 @@ pub fn impl_cmpop( name: "other".into(), }], } - .into(), )), false, ), ); } - } else { - unreachable!() - } + }); } /// Add, Sub, Mult @@ -257,18 +266,18 @@ pub fn impl_mod( } /// UAdd, USub -pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub]) +pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { + impl_unaryop(unifier, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub]) } /// Invert -pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert]) +pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { + impl_unaryop(unifier, ty, ty, &[ast::Unaryop::Invert]) } /// Not pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not]) + impl_unaryop(unifier, ty, store.bool, &[ast::Unaryop::Not]) } /// Lt, LtE, Gt, GtE diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index db7bcae..69b0ede 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -2,4 +2,5 @@ mod function_check; pub mod magic_methods; pub mod type_inferencer; pub mod typedef; +pub mod type_error; mod unification_table; diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs new file mode 100644 index 0000000..3533d6c --- /dev/null +++ b/nac3core/src/typecheck/type_error.rs @@ -0,0 +1,177 @@ +use std::fmt::Display; +use std::collections::HashMap; + +use crate::typecheck::typedef::TypeEnum; + +use super::typedef::{Type, Unifier, RecordKey}; +use nac3parser::ast::{Location, StrRef}; + +#[derive(Debug, Clone)] +pub enum TypeErrorKind { + TooManyArguments { + expected: usize, + got: usize, + }, + MissingArgs(String), + UnknownArgName(StrRef), + IncorrectArgType { + name: StrRef, + expected: Type, + got: Type, + }, + FieldUnificationError { + field: RecordKey, + types: (Type, Type), + loc: (Option, Option), + }, + IncompatibleRange(Type, Vec), + IncompatibleTypes(Type, Type), + MutationError(RecordKey, Type), + NoSuchField(RecordKey, Type), + TupleIndexOutOfBounds { + index: i32, + len: i32, + }, + RequiresTypeAnn, + PolymorphicFunctionPointer, +} + +#[derive(Debug, Clone)] +pub struct TypeError { + pub kind: TypeErrorKind, + pub loc: Option, +} + +impl TypeError { + pub fn new(kind: TypeErrorKind, loc: Option) -> TypeError { + TypeError { kind, loc } + } + + pub fn at(mut self, loc: Option) -> TypeError { + self.loc = self.loc.or(loc); + self + } + + pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError { + DisplayTypeError { + err: self, + unifier + } + } +} + +pub struct DisplayTypeError<'a> { + pub err: TypeError, + pub unifier: &'a Unifier +} + +fn loc_to_str(loc: Option) -> String { + match loc { + Some(loc) => format!("(in {})", loc), + None => "".to_string(), + } +} + +impl<'a> Display for DisplayTypeError<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + use TypeErrorKind::*; + let mut notes = Some(HashMap::new()); + match &self.err.kind { + TooManyArguments { expected, got } => { + write!(f, "Too many arguments. Expected {} but got {}", expected, got) + } + MissingArgs(args) => { + write!(f, "Missing arguments: {}", args) + } + UnknownArgName(name) => { + write!(f, "Unknown argument name: {}", name) + } + IncorrectArgType { + name, + expected, + got, + } => { + let expected = self.unifier.stringify_with_notes(*expected, &mut notes); + let got = self.unifier.stringify_with_notes(*got, &mut notes); + write!( + f, + "Incorrect argument type for {}. Expected {}, but got {}", + name, expected, got + ) + }, + FieldUnificationError { field, types, loc } => { + let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); + let rhs = self.unifier.stringify_with_notes(types.1, &mut notes); + write!( + f, + "Unable to unify field {}: Got types {}{} and {}{}", + field, lhs, loc_to_str(loc.0), rhs, loc_to_str(loc.1) + ) + } + IncompatibleRange(t, ts) => { + let t = self.unifier.stringify_with_notes(*t, &mut notes); + let ts = ts.iter().map(|t| self.unifier.stringify_with_notes(*t, &mut notes)).collect::>(); + write!(f, "Expected any one of these types: {}, but got {}", ts.join(", "), t) + } + IncompatibleTypes(t1, t2) => { + let type1 = self.unifier.get_ty_immutable(*t1); + let type2 = self.unifier.get_ty_immutable(*t2); + match (&*type1, &*type2) { + (TypeEnum::TCall(calls), _) => { + let loc = self.unifier.calls[calls[0].0].loc; + let result = write!(f, "{} is not callable", self.unifier.stringify_with_notes(*t2, &mut notes)); + if let Some(loc) = loc { + result?; + write!(f, " (in {})", loc)?; + return Ok(()) + } + result + } + (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) if ty1.len() != ty2.len() => { + let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); + let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); + write!(f, "Tuple length mismatch: got {} and {}", t1, t2) + } + _ => { + let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); + let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); + write!(f, "Incompatible types: {} and {}", t1, t2) + } + } + } + MutationError(name, t) => { + if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty_immutable(*t) { + write!(f, "Cannot assign to an element of a tuple") + } else { + let t = self.unifier.stringify_with_notes(*t, &mut notes); + write!(f, "Cannot assign to field {} of {}, which is immutable", name, t) + } + } + NoSuchField(name, t) => { + let t = self.unifier.stringify_with_notes(*t, &mut notes); + write!(f, "`{}::{}` field does not exist", t, name) + } + TupleIndexOutOfBounds { index, len } => { + write!(f, "Tuple index out of bounds. Got {} but tuple has only {} elements", index, len) + } + RequiresTypeAnn => { + write!(f, "Unable to infer virtual object type: Type annotation required") + } + PolymorphicFunctionPointer => { + write!(f, "Polymorphic function pointers is not supported") + } + }?; + if let Some(loc) = self.err.loc { + write!(f, " at {}", loc)?; + } + let notes = notes.unwrap(); + if !notes.is_empty() { + write!(f, "\n\nNotes:")?; + for line in notes.values() { + write!(f, "\n {}", line)?; + } + } + Ok(()) + } +} + diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 2634fa7..11f86d8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; -use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; +use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier, RecordField}; use super::{magic_methods::*, typedef::CallId}; use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; use itertools::izip; @@ -147,7 +147,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { self.defined_identifiers.insert(name); } if let Some(old_typ) = self.variable_mapping.insert(name, typ) { - self.unifier.unify(old_typ, typ)?; + let loc = handler.location; + self.unifier.unify(old_typ, typ).map_err(|e| e.at(Some(loc)) + .to_display(self.unifier).to_string())?; } } let mut type_ = naive_folder.fold_expr(*type_)?; @@ -249,7 +251,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } }) .collect(); - let targets = targets?; + let loc = node.location; + let targets = targets.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?; return Ok(Located { location: node.location, node: ast::StmtKind::Assign { @@ -310,11 +313,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { // if we can simply unify without creating new types... let mut fast_path = false; if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { - let fields = fields.borrow(); fast_path = true; if let Some(enter) = fields.get(&"__enter__".into()).cloned() { if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter.0) { - let signature = signature.borrow(); if !signature.args.is_empty() { return report_error( "__enter__ method should take no argument other than self", @@ -343,7 +344,6 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } if let Some(exit) = fields.get(&"__exit__".into()).cloned() { if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit.0) { - let signature = signature.borrow(); if !signature.args.is_empty() { return report_error( "__exit__ method should take no argument other than self", @@ -361,24 +361,24 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } if !fast_path { - let enter = TypeEnum::TFunc(RefCell::new(FunSignature { + let enter = TypeEnum::TFunc(FunSignature { args: vec![], ret: item.optional_vars.as_ref().map_or_else( - || self.unifier.get_fresh_var().0, + || self.unifier.get_dummy_var().0, |var| var.custom.unwrap(), ), vars: Default::default(), - })); + }); let enter = self.unifier.add_ty(enter); - let exit = TypeEnum::TFunc(RefCell::new(FunSignature { + let exit = TypeEnum::TFunc(FunSignature { args: vec![], - ret: self.unifier.get_fresh_var().0, + ret: self.unifier.get_dummy_var().0, vars: Default::default(), - })); + }); let exit = self.unifier.add_ty(exit); let mut fields = HashMap::new(); - fields.insert("__enter__".into(), (enter, false)); - fields.insert("__exit__".into(), (exit, false)); + fields.insert("__enter__".into(), RecordField::new(enter, false, None)); + fields.insert("__exit__".into(), RecordField::new(exit, false, None)); let record = self.unifier.add_record(fields); self.unify(ty, record, &stmt.location)?; } @@ -455,8 +455,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::ExprKind::Compare { left, ops, comparators } => { Some(self.infer_compare(left, ops, comparators)?) } - ast::ExprKind::Subscript { value, slice, .. } => { - Some(self.infer_subscript(value.as_ref(), slice.as_ref())?) + ast::ExprKind::Subscript { value, slice, ctx, .. } => { + Some(self.infer_subscript(value.as_ref(), slice.as_ref(), ctx)?) } ast::ExprKind::IfExp { test, body, orelse } => { Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) @@ -477,11 +477,11 @@ impl<'a> Inferencer<'a> { /// Constrain a <: b /// Currently implemented as unification fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { - self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) + self.unify(a, b, location) } fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { - self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) + self.unifier.unify(a, b).map_err(|e| e.at(Some(*location)).to_display(self.unifier).to_string()) } fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> { @@ -511,17 +511,17 @@ impl<'a> Inferencer<'a> { ret: Option, ) -> InferenceResult { if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { - if class_params.borrow().is_empty() { - if let Some(ty) = fields.borrow().get(&method) { + if class_params.is_empty() { + if let Some(ty) = fields.get(&method) { let ty = ty.0; if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - let sign = sign.borrow(); if sign.vars.is_empty() { let call = Call { posargs: params, kwargs: HashMap::new(), ret: sign.ret, fun: RefCell::new(None), + loc: Some(location), }; if let Some(ret) = ret { self.unifier.unify(sign.ret, ret).unwrap(); @@ -534,25 +534,26 @@ impl<'a> Inferencer<'a> { .rev() .collect(); self.unifier - .unify_call(&call, ty, &sign, &required) - .map_err(|old| format!("{} at {}", old, location))?; + .unify_call(&call, ty, sign, &required) + .map_err(|e| e.at(Some(location)).to_display(self.unifier).to_string())?; return Ok(sign.ret); } } } } } - let ret = ret.unwrap_or_else(|| self.unifier.get_fresh_var().0); + let ret = ret.unwrap_or_else(|| self.unifier.get_dummy_var().0); let call = self.unifier.add_call(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None), + loc: Some(location), }); self.calls.insert(location.into(), call); - let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); - let fields = once((method, (call, false))).collect(); + let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); + let fields = once((method.into(), RecordField::new(call, false, Some(location)))).collect(); let record = self.unifier.add_record(fields); self.constrain(obj, record, &location)?; Ok(ret) @@ -585,10 +586,10 @@ impl<'a> Inferencer<'a> { } } let fn_args: Vec<_> = - args.args.iter().map(|v| (v.node.arg, self.unifier.get_fresh_var().0)).collect(); + args.args.iter().map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)).collect(); let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().cloned()); - let ret = self.unifier.get_fresh_var().0; + let ret = self.unifier.get_dummy_var().0; let mut new_context = Inferencer { function_data: self.function_data, @@ -620,7 +621,7 @@ impl<'a> Inferencer<'a> { Ok(Located { location, node: ExprKind::Lambda { args: args.into(), body: body.into() }, - custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))), + custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))), }) } @@ -725,7 +726,7 @@ impl<'a> Inferencer<'a> { &arg, )? } else { - self.unifier.get_fresh_var().0 + self.unifier.get_dummy_var().0 }; self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location)); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); @@ -774,7 +775,6 @@ impl<'a> Inferencer<'a> { .collect::, _>>()?; if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { - let sign = sign.borrow(); if sign.vars.is_empty() { let call = Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), @@ -784,6 +784,7 @@ impl<'a> Inferencer<'a> { .collect(), fun: RefCell::new(None), ret: sign.ret, + loc: Some(location) }; let required: Vec<_> = sign .args @@ -793,8 +794,8 @@ impl<'a> Inferencer<'a> { .rev() .collect(); self.unifier - .unify_call(&call, func.custom.unwrap(), &sign, &required) - .map_err(|old| format!("{} at {}", old, location))?; + .unify_call(&call, func.custom.unwrap(), sign, &required) + .map_err(|e| e.at(Some(location)).to_display(self.unifier).to_string())?; return Ok(Located { location, custom: Some(sign.ret), @@ -803,7 +804,7 @@ impl<'a> Inferencer<'a> { } } - let ret = self.unifier.get_fresh_var().0; + let ret = self.unifier.get_dummy_var().0; let call = self.unifier.add_call(Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords @@ -812,9 +813,10 @@ impl<'a> Inferencer<'a> { .collect(), fun: RefCell::new(None), ret, + loc: Some(location) }); self.calls.insert(location.into(), call); - let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); + let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); self.unify(func.custom.unwrap(), call, &func.location)?; Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) @@ -831,7 +833,7 @@ impl<'a> Inferencer<'a> { .resolver .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .unwrap_or_else(|_| { - let ty = unifier.get_fresh_var().0; + let ty = unifier.get_dummy_var().0; variable_mapping.insert(id, ty); ty })) @@ -867,7 +869,7 @@ impl<'a> Inferencer<'a> { } fn infer_list(&mut self, elts: &[ast::Expr>]) -> InferenceResult { - let (ty, _) = self.unifier.get_fresh_var(); + let ty = self.unifier.get_dummy_var().0; for t in elts.iter() { self.unify(ty, t.custom.unwrap(), &t.location)?; } @@ -888,7 +890,6 @@ impl<'a> Inferencer<'a> { let ty = value.custom.unwrap(); if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { // just a fast path - let fields = fields.borrow(); match (fields.get(&attr), ctx == &ExprContext::Store) { (Some((ty, true)), _) => Ok(*ty), (Some((ty, false)), false) => Ok(*ty), @@ -898,8 +899,9 @@ impl<'a> Inferencer<'a> { (None, _) => report_error(&format!("No such field {}", attr), value.location), } } else { - let (attr_ty, _) = self.unifier.get_fresh_var(); - let fields = once((attr, (attr_ty, ctx == &ExprContext::Store))).collect(); + let attr_ty = self.unifier.get_dummy_var().0; + let fields = once((attr.into(), RecordField::new( + attr_ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); let record = self.unifier.add_record(fields); self.constrain(value.custom.unwrap(), record, &value.location)?; Ok(attr_ty) @@ -965,8 +967,9 @@ impl<'a> Inferencer<'a> { &mut self, value: &ast::Expr>, slice: &ast::Expr>, + ctx: &ExprContext, ) -> InferenceResult { - let ty = self.unifier.get_fresh_var().0; + let ty = self.unifier.get_dummy_var().0; match &slice.node { ast::ExprKind::Slice { lower, upper, step } => { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { @@ -983,8 +986,9 @@ impl<'a> Inferencer<'a> { None => None, }; let ind = ind.ok_or_else(|| "Index must be int32".to_string())?; - let map = once((ind, ty)).collect(); - let seq = self.unifier.add_sequence(map); + let map = once((ind.into(), RecordField::new( + ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); + let seq = self.unifier.add_record(map); self.constrain(value.custom.unwrap(), seq, &value.location)?; Ok(ty) } @@ -1005,9 +1009,7 @@ impl<'a> Inferencer<'a> { orelse: &ast::Expr>, ) -> InferenceResult { self.constrain(test.custom.unwrap(), self.primitives.bool, &test.location)?; - let ty = self.unifier.get_fresh_var().0; - self.constrain(body.custom.unwrap(), ty, &body.location)?; - self.constrain(orelse.custom.unwrap(), ty, &orelse.location)?; - Ok(ty) + self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?; + Ok(body.custom.unwrap()) } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index e845985..c922ed1 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,4 +1,4 @@ -use super::super::typedef::*; +use super::super::{typedef::*, magic_methods::with_fields}; use super::*; use crate::{ codegen::CodeGenContext, @@ -40,8 +40,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, id: StrRef) -> Option { - self.id_to_def.get(&id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Result { + self.id_to_def.get(&id).cloned().ok_or("Unknown identifier".to_string()) } fn get_string_id(&self, _: &str) -> i32 { @@ -69,7 +69,7 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); - if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) { + with_fields(&mut unifier, int32, |unifier, fields| { let add_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], @@ -78,8 +78,8 @@ impl TestEnvironment { } .into(), )); - fields.borrow_mut().insert("__add__".into(), (add_ty, false)); - } + fields.insert("__add__".into(), (add_ty, false)); + }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), fields: HashMap::new().into(), @@ -170,7 +170,7 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); - if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) { + with_fields(&mut unifier, int32, |unifier, fields| { let add_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], @@ -179,8 +179,8 @@ impl TestEnvironment { } .into(), )); - fields.borrow_mut().insert("__add__".into(), (add_ty, false)); - } + fields.insert("__add__".into(), (add_ty, false)); + }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), fields: HashMap::new().into(), @@ -230,6 +230,7 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, + loc: None }) .into(), ); @@ -238,7 +239,7 @@ impl TestEnvironment { let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; - let (v0, id) = unifier.get_fresh_var(); + let (v0, id) = unifier.get_dummy_var(); let foo_ty = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(defs + 1), @@ -255,6 +256,7 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, + loc: None, }) .into(), ); @@ -293,6 +295,7 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, + loc: None }) .into(), ); @@ -322,6 +325,7 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, + loc: None }) .into(), ); @@ -416,7 +420,7 @@ impl TestEnvironment { c = 1.234 d = b(c) "}, - [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), + [("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), &[] ; "lambda test")] #[test_case(indoc! {" @@ -425,7 +429,7 @@ impl TestEnvironment { a = b c = b(1) "}, - [("a", "fn[[x=int32], int32]"), ("b", "fn[[x=int32], int32]"), ("c", "int32")].iter().cloned().collect(), + [("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].iter().cloned().collect(), &[] ; "lambda test 2")] #[test_case(indoc! {" @@ -441,8 +445,8 @@ impl TestEnvironment { b(123) "}, - [("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"), - ("d", "int32"), ("foo1", "Foo[1->bool]"), ("foo2", "Foo[1->int32]")].iter().cloned().collect(), + [("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"), + ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(), &[] ; "obj test")] #[test_case(indoc! {" @@ -485,33 +489,37 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); for (k, v) in inferencer.variable_mapping.iter() { - let name = inferencer.unifier.stringify( + let name = inferencer.unifier.internal_stringify( *v, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); println!("{}: {}", k, name); } for (k, v) in mapping.iter() { let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); - let name = inferencer.unifier.stringify( + let name = inferencer.unifier.internal_stringify( *ty, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } assert_eq!(inferencer.virtual_checks.len(), virtuals.len()); for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { - let a = inferencer.unifier.stringify( + let a = inferencer.unifier.internal_stringify( *a, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); - let b = inferencer.unifier.stringify( + let b = inferencer.unifier.internal_stringify( *b, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); assert_eq!(&a, x); @@ -627,19 +635,21 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); for (k, v) in inferencer.variable_mapping.iter() { - let name = inferencer.unifier.stringify( + let name = inferencer.unifier.internal_stringify( *v, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); println!("{}: {}", k, name); } for (k, v) in mapping.iter() { let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); - let name = inferencer.unifier.stringify( + let name = inferencer.unifier.internal_stringify( *ty, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), + &mut None ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 8eb6276..516191d 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,13 +1,15 @@ use itertools::{zip, Itertools}; use std::cell::RefCell; use std::collections::HashMap; +use std::fmt::Display; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use nac3parser::ast::StrRef; +use nac3parser::ast::{StrRef, Location}; use super::unification_table::{UnificationKey, UnificationTable}; +use super::type_error::{TypeError, TypeErrorKind}; use crate::symbol_resolver::SymbolValue; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; @@ -18,7 +20,7 @@ mod test; pub type Type = UnificationKey; #[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct CallId(usize); +pub struct CallId(pub(super) usize); pub type Mapping = HashMap; type VarMap = Mapping; @@ -29,6 +31,7 @@ pub struct Call { pub kwargs: HashMap, pub ret: Type, pub fun: RefCell>, + pub loc: Option, } #[derive(Clone)] @@ -45,23 +48,76 @@ pub struct FunSignature { pub vars: VarMap, } -#[derive(Clone)] -pub enum TypeVarMeta { - Generic, - Sequence(RefCell>), - Record(RefCell>), +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RecordKey { + Str(StrRef), + Int(i32) +} + +impl From<&RecordKey> for StrRef { + fn from(r: &RecordKey) -> Self { + match r { + RecordKey::Str(s) => *s, + RecordKey::Int(i) => StrRef::from(i.to_string()) + } + } +} + +impl From for RecordKey { + fn from(s: StrRef) -> Self { + RecordKey::Str(s) + } +} + +impl From<&str> for RecordKey { + fn from(s: &str) -> Self { + RecordKey::Str(s.into()) + } +} + +impl From for RecordKey { + fn from(i: i32) -> Self { + RecordKey::Int(i) + } +} + +impl Display for RecordKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RecordKey::Str(s) => write!(f, "{}", s), + RecordKey::Int(i) => write!(f, "{}", i) + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct RecordField { + ty: Type, + mutable: bool, + loc: Option +} + +impl RecordField { + pub fn new(ty: Type, mutable: bool, loc: Option) -> RecordField { + RecordField { ty, mutable, loc } + } } #[derive(Clone)] pub enum TypeEnum { TRigidVar { id: u32, + name: Option, + loc: Option }, TVar { id: u32, - meta: TypeVarMeta, + // empty indicates this is not a struct/tuple/list + fields: Option>, // empty indicates no restriction - range: RefCell>, + range: Vec, + name: Option, + loc: Option }, TTuple { ty: Vec, @@ -71,14 +127,14 @@ pub enum TypeEnum { }, TObj { obj_id: DefinitionId, - fields: RefCell>, - params: RefCell, + fields: Mapping, + params: VarMap, }, TVirtual { ty: Type, }, - TCall(RefCell>), - TFunc(RefCell), + TCall(Vec), + TFunc(FunSignature), } impl TypeEnum { @@ -102,7 +158,7 @@ pub type SharedUnifier = Arc, u32, Vec)> pub struct Unifier { pub top_level: Option>, unification_table: UnificationTable>, - calls: Vec>, + pub(super) calls: Vec>, var_id: u32, unify_cache: HashSet<(Type, Type)>, } @@ -125,6 +181,10 @@ impl Unifier { } } + pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable> { + &mut self.unification_table + } + /// Determine if the two types are the same pub fn unioned(&mut self, a: Type, b: Type) -> bool { self.unification_table.unioned(a, b) @@ -155,13 +215,15 @@ impl Unifier { self.unification_table.new_key(Rc::new(a)) } - pub fn add_record(&mut self, fields: Mapping) -> Type { + pub fn add_record(&mut self, fields: Mapping) -> Type { let id = self.var_id + 1; self.var_id += 1; self.add_ty(TypeEnum::TVar { id, - range: vec![].into(), - meta: TypeVarMeta::Record(fields.into()), + range: vec![], + fields: Some(fields), + name: None, + loc: None, }) } @@ -174,7 +236,16 @@ impl Unifier { pub fn get_call_signature(&mut self, id: CallId) -> Option { let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap(); if let TypeEnum::TFunc(sign) = &*self.get_ty(fun) { - Some(sign.borrow().clone()) + Some(sign.clone()) + } else { + None + } + } + + pub fn get_call_signature_immutable(&self, id: CallId) -> Option { + let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap(); + if let TypeEnum::TFunc(sign) = &*self.get_ty_immutable(fun) { + Some(sign.clone()) } else { None } @@ -184,37 +255,35 @@ impl Unifier { self.unification_table.get_representative(ty) } - pub fn add_sequence(&mut self, sequence: Mapping) -> Type { - let id = self.var_id + 1; - self.var_id += 1; - self.add_ty(TypeEnum::TVar { - id, - range: vec![].into(), - meta: TypeVarMeta::Sequence(sequence.into()), - }) - } - /// Get the TypeEnum of a type. pub fn get_ty(&mut self, a: Type) -> Rc { self.unification_table.probe_value(a).clone() } - pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) { - let id = self.var_id + 1; - self.var_id += 1; - (self.add_ty(TypeEnum::TRigidVar { id }), id) + pub fn get_ty_immutable(&self, a: Type) -> Rc { + self.unification_table.probe_value_immutable(a).clone() } - pub fn get_fresh_var(&mut self) -> (Type, u32) { - self.get_fresh_var_with_range(&[]) + pub fn get_fresh_rigid_var(&mut self, name: Option, loc: Option) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + (self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id) + } + + pub fn get_dummy_var(&mut self) -> (Type, u32) { + self.get_fresh_var_with_range(&[], None, None) + } + + pub fn get_fresh_var(&mut self, name: Option, loc: Option) -> (Type, u32) { + self.get_fresh_var_with_range(&[], name, loc) } /// Get a fresh type variable. - pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) { + pub fn get_fresh_var_with_range(&mut self, range: &[Type], name: Option, loc: Option) -> (Type, u32) { let id = self.var_id + 1; self.var_id += 1; - let range = range.to_vec().into(); - (self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id) + let range = range.to_vec(); + (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc}), id) } /// Unification would not unify rigid variables with other types, but we want to do this for @@ -227,7 +296,6 @@ impl Unifier { pub fn get_instantiations(&mut self, ty: Type) -> Option> { match &*self.get_ty(ty) { TypeEnum::TVar { range, .. } => { - let range = range.borrow(); if range.is_empty() { None } else { @@ -261,11 +329,10 @@ impl Unifier { } } TypeEnum::TObj { params, .. } => { - let params = params.borrow(); - let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip(); + let (keys, params): (Vec, Vec) = params.iter().unzip(); let params = params .into_iter() - .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .map(|ty| self.get_instantiations(ty).unwrap_or_else(|| vec![ty])) .multi_cartesian_product() .collect_vec(); if params.len() <= 1 { @@ -277,7 +344,7 @@ impl Unifier { .map(|params| { self.subst( ty, - &zip(keys.iter().cloned().cloned(), params.iter().cloned()) + &zip(keys.iter().cloned(), params.iter().cloned()) .collect(), ) .unwrap_or(ty) @@ -299,7 +366,7 @@ impl Unifier { TList { ty } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { - vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars)) + vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) } // functions are instantiated for each call sites, so the function type can contain // type variables. @@ -314,8 +381,8 @@ impl Unifier { b: Type, signature: &FunSignature, required: &[StrRef], - ) -> Result<(), String> { - let Call { posargs, kwargs, ret, fun } = call; + ) -> Result<(), TypeError> { + let Call { posargs, kwargs, ret, fun, loc } = call; let instantiated = self.instantiate_fun(b, &*signature); let r = self.get_ty(instantiated); let r = r.as_ref(); @@ -329,15 +396,22 @@ impl Unifier { // arguments) are provided, and do not provide the same argument twice. let mut required = required.to_vec(); let mut all_names: Vec<_> = - signature.borrow().args.iter().map(|v| (v.name, v.ty)).rev().collect(); + signature.args.iter().map(|v| (v.name, v.ty)).rev().collect(); for (i, t) in posargs.iter().enumerate() { - if signature.borrow().args.len() <= i { - return Err("Too many arguments.".to_string()); + if signature.args.len() <= i { + return Err(TypeError::new(TypeErrorKind::TooManyArguments{ + expected: signature.args.len(), + got: i, + }, *loc)); } - if !required.is_empty() { - required.pop(); - } - self.unify_impl(all_names.pop().unwrap().1, *t, false)?; + required.pop(); + let (name, expected) = all_names.pop().unwrap(); + self.unify_impl(expected, *t, false) + .map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { + name, + expected, + got: *t, + }, *loc))?; } for (k, t) in kwargs.iter() { if let Some(i) = required.iter().position(|v| v == k) { @@ -346,18 +420,30 @@ impl Unifier { let i = all_names .iter() .position(|v| &v.0 == k) - .ok_or_else(|| format!("Unknown keyword argument {}", k))?; - self.unify_impl(all_names.remove(i).1, *t, false)?; + .ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; + let (name, expected) = all_names.remove(i); + self.unify_impl(expected, *t, false) + .map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { + name, + expected, + got: *t, + }, *loc))?; } if !required.is_empty() { - return Err("Expected more arguments".to_string()); + return Err(TypeError::new(TypeErrorKind::MissingArgs(required.iter().join(", ")), *loc)); } - self.unify_impl(*ret, signature.borrow().ret, false)?; + self.unify_impl(*ret, signature.ret, false) + .map_err(|mut err| { + if err.loc.is_none() { + err.loc = *loc; + } + err + })?; *fun.borrow_mut() = Some(instantiated); Ok(()) } - pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { + pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> { self.unify_cache.clear(); if self.unification_table.unioned(a, b) { Ok(()) @@ -366,9 +452,8 @@ impl Unifier { } } - fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> { + fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), TypeError> { use TypeEnum::*; - use TypeVarMeta::*; if !swapped { let rep_a = self.unification_table.get_representative(a); @@ -386,62 +471,48 @@ impl Unifier { ) }; match (&*ty_a, &*ty_b) { - (TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => { - match (meta1, meta2) { - (Generic, _) => {} - (_, Generic) => { + (TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields2, name: name2, loc: loc2, .. }) => { + let new_fields = match (fields1, fields2) { + (None, None) => None, + (None, Some(fields)) => Some(fields.clone()), + (_, None) => { return self.unify_impl(b, a, true); - } - (Record(fields1), Record(fields2)) => { - let mut fields2 = fields2.borrow_mut(); - for (key, (ty, is_mutable)) in fields1.borrow().iter() { - if let Some((ty2, is_mutable2)) = fields2.get_mut(key) { - self.unify_impl(*ty2, *ty, false)?; - *is_mutable2 |= *is_mutable; + }, + (Some(fields1), Some(fields2)) => { + let mut new_fields: Mapping<_, _> = fields2.clone(); + for (key, val1) in fields1.iter() { + if let Some(val2) = fields2.get(key) { + self.unify_impl(val1.ty, val2.ty, false) + .map_err(|_| TypeError::new(TypeErrorKind::FieldUnificationError { + field: *key, + types: (val1.ty, val2.ty), + loc: (*loc1, *loc2), + }, None))?; + new_fields.insert(*key, RecordField::new(val1.ty, val1.mutable || val2.mutable, val1.loc.or(val2.loc))); } else { - fields2.insert(*key, (*ty, *is_mutable)); + new_fields.insert(*key, *val1); } } + Some(new_fields) } - (Sequence(map1), Sequence(map2)) => { - let mut map2 = map2.borrow_mut(); - for (key, value) in map1.borrow().iter() { - if let Some(ty) = map2.get(key) { - self.unify_impl(*ty, *value, false)?; - } else { - map2.insert(*key, *value); - } - } - } - _ => { - return Err("Incompatible type variables".to_string()); - } - } - let range1 = range1.borrow(); - // new range is the intersection of them - // empty range indicates no constraint - if !range1.is_empty() { - let old_range2 = range2.take(); - let mut range2 = range2.borrow_mut(); - if old_range2.is_empty() { - range2.extend_from_slice(&range1); - } - for v1 in old_range2.iter() { - for v2 in range1.iter() { - if let Ok(result) = self.get_intersection(*v1, *v2) { - range2.push(result.unwrap_or(*v2)); - } - } - } - if range2.is_empty() { - return Err( - "cannot unify type variables with incompatible value range".to_string() - ); - } - } - self.set_a_to_b(a, b); + }; + let intersection = self.get_intersection(a, b).map_err(|_| + TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?.unwrap(); + let range = if let TypeEnum::TVar { range, .. } = &*self.get_ty(intersection) { + range.clone() + } else { + unreachable!() + }; + self.unification_table.unify(a, b); + self.unification_table.set_value(a, Rc::new(TypeEnum::TVar { + id: *id, + fields: new_fields, + range, + name: name1.or(*name2), + loc: loc1.or(*loc2) + })); } - (TVar { meta: Generic, id, range, .. }, _) => { + (TVar { fields: None, range, .. }, _) => { // We check for the range of the type variable to see if unification is allowed. // Note that although b may be compatible with a, we may have to constrain type // variables in b to make sure that instantiations of b would always be compatible @@ -449,42 +520,50 @@ impl Unifier { // The return value x of check_var_compatibility would be a new type that is // guaranteed to be compatible with a under all possible instantiations. So we // unify x with b to recursively apply the constrains, and then set a to x. - let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + let x = self.check_var_compatibility(b, range).map_err(|_| + TypeError::new(TypeErrorKind::IncompatibleRange(b, range.clone()), None))?.unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { + (TVar { fields: Some(fields), range, .. }, TTuple { ty }) => { let len = ty.len() as i32; - for (k, v) in map.borrow().iter() { - // handle negative index - let ind = if *k < 0 { len + *k } else { *k }; - if ind >= len || ind < 0 { - return Err(format!( - "Tuple index out of range. (Length: {}, Index: {})", - len, k - )); + for (k, v) in fields.iter() { + match *k { + RecordKey::Int(i) => { + if v.mutable { + return Err(TypeError::new( + TypeErrorKind::MutationError(*k, b), v.loc)); + } + let ind = if i < 0 { len + i } else { i }; + if ind >= len || ind < 0 { + return Err(TypeError::new( + TypeErrorKind::TupleIndexOutOfBounds{ index: i, len}, v.loc)); + } + self.unify_impl(v.ty, ty[ind as usize], false).map_err(|e| e.at(v.loc))?; + } + RecordKey::Str(_) => return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), v.loc)), } - self.unify_impl(*v, ty[ind as usize], false)?; } - let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + let x = self.check_var_compatibility(b, range)?.unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { - for v in map.borrow().values() { - self.unify_impl(*v, *ty, false)?; + (TVar { fields: Some(fields), range, .. }, TList { ty }) => { + for (k, v) in fields.iter() { + match *k { + RecordKey::Int(_) => self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?, + RecordKey::Str(_) => return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), v.loc)), + } } - let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + let x = self.check_var_compatibility(b, range)?.unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { if ty1.len() != ty2.len() { - return Err(format!( - "Cannot unify tuples with length {} and {}", - ty1.len(), - ty2.len() - )); + return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); } for (x, y) in ty1.iter().zip(ty2.iter()) { self.unify_impl(*x, *y, false)?; @@ -495,47 +574,64 @@ impl Unifier { self.unify_impl(*ty1, *ty2, false)?; self.set_a_to_b(a, b); } - (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { - for (k, (ty, is_mutable)) in map.borrow().iter() { - let (ty2, is_mutable2) = fields - .borrow() - .get(k) - .copied() - .ok_or_else(|| format!("No such attribute {}", k))?; - // typevar represents the usage of the variable - // it is OK to have immutable usage for mutable fields - // but cannot have mutable usage for immutable fields - if *is_mutable && !is_mutable2 { - return Err(format!("Field {} should be immutable", k)); + (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { + for (k, field) in map.iter() { + match *k { + RecordKey::Str(s) => { + let (ty, mutable) = fields + .get(&s) + .copied() + .ok_or_else(|| TypeError::new( + TypeErrorKind::NoSuchField(*k, b), field.loc))?; + // typevar represents the usage of the variable + // it is OK to have immutable usage for mutable fields + // but cannot have mutable usage for immutable fields + if field.mutable && !mutable{ + return Err(TypeError::new( + TypeErrorKind::MutationError(*k, b), field.loc)); + } + self.unify_impl(field.ty, ty, false) + .map_err(|v| v.at(field.loc))?; + } + RecordKey::Int(_) => return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), field.loc)) } - self.unify_impl(*ty, ty2, false)?; } - let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + let x = self.check_var_compatibility(b, range)?.unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { + (TVar { fields: Some(map), range, .. }, TVirtual { ty }) => { let ty = self.get_ty(*ty); if let TObj { fields, .. } = ty.as_ref() { - for (k, (ty, is_mutable)) in map.borrow().iter() { - let (ty2, is_mutable2) = fields - .borrow() - .get(k) - .copied() - .ok_or_else(|| format!("No such attribute {}", k))?; - if !matches!(self.get_ty(ty2).as_ref(), TFunc { .. }) { - return Err(format!("Cannot access field {} for virtual type", k)); + for (k, field) in map.iter() { + match *k { + RecordKey::Str(s) => { + let (ty, _) = fields + .get(&s) + .copied() + .ok_or_else(|| TypeError::new( + TypeErrorKind::NoSuchField(*k, b), field.loc))?; + if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { + return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), field.loc)) + } + if field.mutable { + return Err(TypeError::new( + TypeErrorKind::MutationError(*k, b), field.loc)); + } + self.unify_impl(field.ty, ty, false) + .map_err(|v| v.at(field.loc))?; + } + RecordKey::Int(_) => return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), field.loc)) } - if *is_mutable && !is_mutable2 { - return Err(format!("Field {} should be immutable", k)); - } - self.unify_impl(*ty, ty2, false)?; } } else { // require annotation... - return Err("Requires type annotation for virtual".to_string()); + return Err(TypeError::new(TypeErrorKind::RequiresTypeAnn, None)) } - let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + let x = self.check_var_compatibility(b, range)?.unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } @@ -546,7 +642,7 @@ impl Unifier { if id1 != id2 { self.incompatible_types(a, b)?; } - for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) { + for (x, y) in zip(params1.values(), params2.values()) { self.unify_impl(*x, *y, false)?; } self.set_a_to_b(a, b); @@ -558,11 +654,12 @@ impl Unifier { (TCall(calls1), TCall(calls2)) => { // we do not unify individual calls, instead we defer until the unification wtih a // function definition. - calls2.borrow_mut().extend_from_slice(&calls1.borrow()); + let calls = calls1.iter().chain(calls2.iter()).cloned().collect(); + self.set_a_to_b(a, b); + self.unification_table.set_value(b, Rc::new(TCall(calls))); } (TCall(calls), TFunc(signature)) => { let required: Vec = signature - .borrow() .args .iter() .filter(|v| v.default_value.is_none()) @@ -570,33 +667,32 @@ impl Unifier { .rev() .collect(); // we unify every calls to the function signature. - let signature = signature.borrow(); - for c in calls.borrow().iter() { + for c in calls.iter() { let call = self.calls[c.0].clone(); - self.unify_call(&call, b, &signature, &required)?; + self.unify_call(&call, b, signature, &required)?; } self.set_a_to_b(a, b); } (TFunc(sign1), TFunc(sign2)) => { - let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow()); if !sign1.vars.is_empty() || !sign2.vars.is_empty() { - return Err("Polymorphic function pointer is prohibited.".to_string()); + return Err(TypeError::new(TypeErrorKind::PolymorphicFunctionPointer, None)); } if sign1.args.len() != sign2.args.len() { - return Err("Functions differ in number of parameters.".to_string()); + return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); } for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { - if x.name != y.name { - return Err("Functions differ in parameter names.".to_string()); - } - if x.default_value != y.default_value { - return Err("Functions differ in optional parameters value".to_string()); + if x.name != y.name || x.default_value != y.default_value { + return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); } self.unify_impl(x.ty, y.ty, false)?; } self.unify_impl(sign1.ret, sign2.ret, false)?; self.set_a_to_b(a, b); } + (TVar { fields: Some(fields), .. }, _) => { + let (k, v) = fields.iter().next().unwrap(); + return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)); + } _ => { if swapped { return self.incompatible_types(a, b); @@ -608,9 +704,13 @@ impl Unifier { Ok(()) } - pub fn default_stringify(&mut self, ty: Type) -> String { + pub fn stringify(&self, ty: Type) -> String { + self.stringify_with_notes(ty, &mut None) + } + + pub fn stringify_with_notes(&self, ty: Type, notes: &mut Option>) -> String { let top_level = self.top_level.clone(); - self.stringify( + self.internal_stringify( ty, &mut |id| { top_level.as_ref().map_or_else( @@ -627,54 +727,50 @@ impl Unifier { ) }, &mut |id| format!("var{}", id), + notes ) } /// Get string representation of the type - pub fn stringify(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String + pub fn internal_stringify(&self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G, notes: &mut Option>) -> String where F: FnMut(usize) -> String, G: FnMut(u32) -> String, { - use TypeVarMeta::*; - let ty = self.unification_table.probe_value(ty).clone(); + let ty = self.unification_table.probe_value_immutable(ty).clone(); match ty.as_ref() { - TypeEnum::TRigidVar { id } => var_to_name(*id), - TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id), - TypeEnum::TVar { meta: Sequence(map), .. } => { - let fields = map - .borrow() - .iter() - .map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))) - .join(", "); - format!("seq[{}]", fields) - } - TypeEnum::TVar { meta: Record(fields), .. } => { - let fields = fields - .borrow() - .iter() - .map(|(k, (v, _))| { - format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name)) - }) - .join(", "); - format!("record[{}]", fields) + TypeEnum::TRigidVar { id, name, .. } => name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), + TypeEnum::TVar { id, name, fields, range, .. } => { + let n = if let Some(fields) = fields { + let mut fields = fields.iter().map(|(k, f)| format!("{}={}", k, self.internal_stringify(f.ty, obj_to_name, var_to_name, notes))); + let fields = fields.join(", "); + format!("{}[{}]", name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), fields) + } else { + name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)) + }; + if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id) { + // just in case if there is any cyclic dependency + notes.as_mut().unwrap().insert(*id, "".into()); + let body = format!("{} ∈ {{{}}}", n, range.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)).collect::>().join(", ")); + notes.as_mut().unwrap().insert(*id, body); + }; + n } TypeEnum::TTuple { ty } => { - let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name)); + let mut fields = ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); format!("tuple[{}]", fields.join(", ")) } TypeEnum::TList { ty } => { - format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name)) + format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } TypeEnum::TVirtual { ty } => { - format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name)) + format!("virtual[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } TypeEnum::TObj { obj_id, params, .. } => { let name = obj_to_name(obj_id.0); - let params = params.borrow(); if !params.is_empty() { - let params = params.iter().map(|(id, v)| { - format!("{}->{}", *id, self.stringify(*v, obj_to_name, var_to_name)) + let params = params.iter().map(|(_, v)| { + self.internal_stringify(*v, obj_to_name, var_to_name, notes) }); // sort to preserve order let mut params = params.sorted(); @@ -686,14 +782,17 @@ impl Unifier { TypeEnum::TCall { .. } => "call".to_owned(), TypeEnum::TFunc(signature) => { let params = signature - .borrow() .args .iter() .map(|arg| { - format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name)) + if let Some(dv) = &arg.default_value { + format!("{}:{}={}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), dv) + } else { + format!("{}:{}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)) + } }) .join(", "); - let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name); + let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); format!("fn[[{}], {}]", params, ret) } } @@ -707,12 +806,8 @@ impl Unifier { table.set_value(a, ty_b) } - fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), String> { - Err(format!( - "Cannot unify {} with {}", - self.default_stringify(a), - self.default_stringify(b) - )) + fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> { + Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)) } /// Instantiate a function if it hasn't been instantiated. @@ -722,7 +817,7 @@ impl Unifier { let mut instantiated = true; let mut vars = Vec::new(); for (k, v) in fun.vars.iter() { - if let TypeEnum::TVar { id, range, .. } = + if let TypeEnum::TVar { id, name, loc, range, .. } = self.unification_table.probe_value(*v).as_ref() { // for class methods that contain type vars not in class declaration, @@ -730,7 +825,7 @@ impl Unifier { // and need to do substitution on those type vars if k == id { instantiated = false; - vars.push((*k, range.clone())); + vars.push((*k, range.clone(), *name, *loc)); } } } @@ -739,7 +834,7 @@ impl Unifier { } else { let mapping = vars .into_iter() - .map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0)) + .map(|(k, range, name, loc)| (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0)) .collect(); self.subst(ty, &mapping).unwrap_or(ty) } @@ -762,7 +857,7 @@ impl Unifier { let cached = cache.get_mut(&a); if let Some(cached) = cached { if cached.is_none() { - *cached = Some(self.get_fresh_var().0); + *cached = Some(self.get_fresh_var(None, None).0); } return *cached; } @@ -799,7 +894,6 @@ impl Unifier { // If the mapping does not contain any type variables in the // parameter list, we don't need to substitute the fields. // This is also used to prevent infinite substitution... - let params = params.borrow(); let need_subst = params.values().any(|v| { let ty = self.unification_table.probe_value(*v); if let TypeEnum::TVar { id, .. } = ty.as_ref() { @@ -812,15 +906,11 @@ impl Unifier { cache.insert(a, None); let obj_id = *obj_id; let params = - self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone()); + self.subst_map(params, mapping, cache).unwrap_or_else(|| params.clone()); let fields = self - .subst_map2(&fields.borrow(), mapping, cache) - .unwrap_or_else(|| fields.borrow().clone()); - let new_ty = self.add_ty(TypeEnum::TObj { - obj_id, - params: params.into(), - fields: fields.into(), - }); + .subst_map2(fields, mapping, cache) + .unwrap_or_else(|| fields.clone()); + let new_ty = self.add_ty(TypeEnum::TObj { obj_id, params, fields }); if let Some(var) = cache.get(&a).unwrap() { self.unify_impl(new_ty, *var, false).unwrap(); } @@ -829,8 +919,7 @@ impl Unifier { None } } - TypeEnum::TFunc(sig) => { - let FunSignature { args, ret, vars: params } = &*sig.borrow(); + TypeEnum::TFunc(FunSignature { args, ret, vars: params }) => { let new_params = self.subst_map(params, mapping, cache); let new_ret = self.subst_impl(*ret, mapping, cache); let mut new_args = Cow::from(args); @@ -845,11 +934,7 @@ impl Unifier { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.into_owned(); - Some( - self.add_ty(TypeEnum::TFunc( - FunSignature { args, ret, vars: params }.into(), - )), - ) + Some( self.add_ty(TypeEnum::TFunc( FunSignature { args, ret, vars: params })),) } else { None } @@ -907,40 +992,28 @@ impl Unifier { let x = self.get_ty(a); let y = self.get_ty(b); match (x.as_ref(), y.as_ref()) { - (TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => { - // we should restrict range2 - let range1 = range1.borrow(); + (TVar { range: range1, name, loc, .. }, TVar { fields, range: range2, name: name2, loc: loc2, .. }) => { // new range is the intersection of them // empty range indicates no constraint - if !range1.is_empty() { - let range2 = range2.borrow(); - let mut range = Vec::new(); - if range2.is_empty() { - range.extend_from_slice(&range1); - } - for v1 in range2.iter() { - for v2 in range1.iter() { - let result = self.get_intersection(*v1, *v2); - if let Ok(result) = result { - range.push(result.unwrap_or(*v2)); - } - } - } + if range1.is_empty() { + Ok(Some(b)) + } else if range2.is_empty() { + Ok(Some(a)) + } else { + let range = range2.iter().cartesian_product(range1.iter()) + .filter_map(|(v1, v2)| self.get_intersection(*v1, *v2).map(|v| v.unwrap_or(*v1)).ok()).collect_vec(); if range.is_empty() { Err(()) } else { let id = self.var_id + 1; self.var_id += 1; - let ty = TVar { id, meta: meta.clone(), range: range.into() }; + let ty = TVar { id, fields: fields.clone(), range, name: name2.or(*name), loc: loc2.or(*loc) }; Ok(Some(self.unification_table.new_key(ty.into()))) } - } else { - Ok(Some(b)) } } (_, TVar { range, .. }) => { // range should be restricted to the left hand side - let range = range.borrow(); if range.is_empty() { Ok(Some(a)) } else { @@ -953,24 +1026,13 @@ impl Unifier { Err(()) } } - (TVar { id, range, .. }, _) => { - self.check_var_compatibility(*id, b, &range.borrow()).or(Err(())) + (TVar { range, .. }, _) => { + self.check_var_compatibility(b, range).or(Err(())) } - (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { - if ty1.len() != ty2.len() { - return Err(()); - } - let mut need_new = false; - let mut ty = ty1.clone(); - for (a, b) in zip(ty1.iter(), ty2.iter()) { - let result = self.get_intersection(*a, *b)?; - ty.push(result.unwrap_or(*a)); - if result.is_some() { - need_new = true; - } - } - if need_new { - Ok(Some(self.add_ty(TTuple { ty }))) + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => { + let ty: Vec<_> = zip(ty1.iter(), ty2.iter()).map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?; + if ty.iter().any(Option::is_some) { + Ok(Some(self.add_ty(TTuple { ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect()}))) } else { Ok(None) } @@ -981,12 +1043,8 @@ impl Unifier { (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } - (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => { - if id1 == id2 { - Ok(None) - } else { - Err(()) - } + (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) if id1 == id2 => { + Ok(None) } // don't deal with function shape for now _ => Err(()), @@ -995,10 +1053,9 @@ impl Unifier { fn check_var_compatibility( &mut self, - id: u32, b: Type, range: &[Type], - ) -> Result, String> { + ) -> Result, TypeError> { if range.is_empty() { return Ok(None); } @@ -1008,10 +1065,6 @@ impl Unifier { return Ok(result); } } - return Err(format!( - "Cannot unify variable {} with {} due to incompatible value range", - id, - self.default_stringify(b) - )); + Err(TypeError::new(TypeErrorKind::IncompatibleRange(b, range.to_vec()), None)) } } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index e4fbee6..56e0fd3 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -1,4 +1,5 @@ use super::*; +use super::super::magic_methods::with_fields; use indoc::indoc; use itertools::Itertools; use std::collections::HashMap; @@ -7,7 +8,6 @@ use test_case::test_case; impl Unifier { /// Check whether two types are equal. fn eq(&mut self, a: Type, b: Type) -> bool { - use TypeVarMeta::*; if a == b { return true; } @@ -21,13 +21,13 @@ impl Unifier { match (&*ty_a, &*ty_b) { ( - TypeEnum::TVar { meta: Generic, id: id1, .. }, - TypeEnum::TVar { meta: Generic, id: id2, .. }, + TypeEnum::TVar { fields: None, id: id1, .. }, + TypeEnum::TVar { fields: None, id: id2, .. }, ) => id1 == id2, ( - TypeEnum::TVar { meta: Sequence(map1), .. }, - TypeEnum::TVar { meta: Sequence(map2), .. }, - ) => self.map_eq(&map1.borrow(), &map2.borrow()), + TypeEnum::TVar { fields: Some(map1), .. }, + TypeEnum::TVar { fields: Some(map2), .. }, + ) => self.map_eq2(map1, map2), (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { ty1.len() == ty2.len() && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) @@ -36,14 +36,10 @@ impl Unifier { | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => { self.eq(*ty1, *ty2) } - ( - TypeEnum::TVar { meta: Record(fields1), .. }, - TypeEnum::TVar { meta: Record(fields2), .. }, - ) => self.map_eq2(&fields1.borrow(), &fields2.borrow()), ( TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. }, - ) => id1 == id2 && self.map_eq(¶ms1.borrow(), ¶ms2.borrow()), + ) => id1 == id2 && self.map_eq(params1, params2), // TCall and TFunc are not yet implemented _ => false, } @@ -64,19 +60,15 @@ impl Unifier { true } - fn map_eq2( - &mut self, - map1: &Mapping, - map2: &Mapping, - ) -> bool + fn map_eq2(&mut self, map1: &Mapping, map2: &Mapping) -> bool where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { if map1.len() != map2.len() { return false; } - for (k, (ty1, m1)) in map1.iter() { - if !map2.get(k).map(|(ty2, m2)| m1 == m2 && self.eq(*ty1, *ty2)).unwrap_or(false) { + for (k, v) in map1.iter() { + if !map2.get(k).map(|v1| self.eq(v.ty, v1.ty)).unwrap_or(false) { return false; } } @@ -98,27 +90,27 @@ impl TestEnvironment { "int".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }), ); type_mapping.insert( "float".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }), ); type_mapping.insert( "bool".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(2), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }), ); - let (v0, id) = unifier.get_fresh_var(); + let (v0, id) = unifier.get_dummy_var(); type_mapping.insert( "Foo".into(), unifier.add_ty(TypeEnum::TObj { @@ -126,9 +118,8 @@ impl TestEnvironment { fields: [("a".into(), (v0, true))] .iter() .cloned() - .collect::>() - .into(), - params: [(id, v0)].iter().cloned().collect::>().into(), + .collect::>(), + params: [(id, v0)].iter().cloned().collect::>(), }), ); @@ -174,7 +165,7 @@ impl TestEnvironment { let eq = s.find('=').unwrap(); let key = s[1..eq].into(); let result = self.internal_parse(&s[eq + 1..], mapping); - fields.insert(key, (result.0, true)); + fields.insert(key, RecordField::new(result.0, true, None)); s = result.1; } (self.unifier.add_record(fields), &s[1..]) @@ -187,7 +178,6 @@ impl TestEnvironment { let mut ty = *self.type_mapping.get(x).unwrap(); let te = self.unifier.get_ty(ty); if let TypeEnum::TObj { params, .. } = &*te.as_ref() { - let params = params.borrow(); if !params.is_empty() { assert!(&s[0..1] == "["); let mut p = Vec::new(); @@ -209,6 +199,10 @@ impl TestEnvironment { } } } + + fn unify(&mut self, typ1: Type, typ2: Type) -> Result<(), String> { + self.unifier.unify(typ1, typ2).map_err(|e| e.to_display(&self.unifier).to_string()) + } } #[test_case(2, @@ -258,7 +252,7 @@ fn test_unify( let mut env = TestEnvironment::new(); let mut mapping = HashMap::new(); for i in 1..=variable_count { - let v = env.unifier.get_fresh_var(); + let v = env.unifier.get_dummy_var(); mapping.insert(format!("v{}", i), v.0); } // unification may have side effect when we do type resolution, so freeze the types @@ -276,6 +270,7 @@ fn test_unify( println!("{} = {}", a, b); let t1 = env.parse(a, &mapping); let t2 = env.parse(b, &mapping); + println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2)); assert!(env.unifier.eq(t1, t2)); } } @@ -286,7 +281,7 @@ fn test_unify( ("v1", "tuple[int]"), ("v2", "list[int]"), ], - (("v1", "v2"), "Cannot unify list[0] with tuple[0]") + (("v1", "v2"), "Incompatible types: list[0] and tuple[0]") ; "type mismatch" )] #[test_case(2, @@ -294,7 +289,7 @@ fn test_unify( ("v1", "tuple[int]"), ("v2", "tuple[float]"), ], - (("v1", "v2"), "Cannot unify 0 with 1") + (("v1", "v2"), "Incompatible types: 0 and 1") ; "tuple parameter mismatch" )] #[test_case(2, @@ -302,7 +297,7 @@ fn test_unify( ("v1", "tuple[int,int]"), ("v2", "tuple[int]"), ], - (("v1", "v2"), "Cannot unify tuples with length 2 and 1") + (("v1", "v2"), "Tuple length mismatch: got tuple[0, 0] and tuple[0]") ; "tuple length mismatch" )] #[test_case(3, @@ -310,7 +305,7 @@ fn test_unify( ("v1", "Record[a=float,b=int]"), ("v2", "Foo[v3]"), ], - (("v1", "v2"), "No such attribute b") + (("v1", "v2"), "`3[var4]::b` field does not exist") ; "record obj merge" )] /// Test cases for invalid unifications. @@ -322,7 +317,7 @@ fn test_invalid_unification( let mut env = TestEnvironment::new(); let mut mapping = HashMap::new(); for i in 1..=variable_count { - let v = env.unifier.get_fresh_var(); + let v = env.unifier.get_dummy_var(); mapping.insert(format!("v{}", i), v.0); } // unification may have side effect when we do type resolution, so freeze the types @@ -338,7 +333,7 @@ fn test_invalid_unification( for (a, b) in pairs { env.unifier.unify(a, b).unwrap(); } - assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); + assert_eq!(env.unify(t1, t2), Err(errornous_pair.1.to_string())); } #[test] @@ -348,16 +343,17 @@ fn test_recursive_subst() { let foo_id = *env.type_mapping.get("Foo").unwrap(); let foo_ty = env.unifier.get_ty(foo_id); let mapping: HashMap<_, _>; - if let TypeEnum::TObj { fields, params, .. } = &*foo_ty { - fields.borrow_mut().insert("rec".into(), (foo_id, true)); - mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect(); + with_fields(&mut env.unifier, foo_id, |_unifier, fields| { + fields.insert("rec".into(), (foo_id, true)); + }); + if let TypeEnum::TObj { params, .. } = &*foo_ty { + mapping = params.iter().map(|(id, _)| (*id, int)).collect(); } else { unreachable!() } let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated_ty = env.unifier.get_ty(instantiated); if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { - let fields = fields.borrow(); assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); } else { @@ -370,32 +366,31 @@ fn test_virtual() { let mut env = TestEnvironment::new(); let int = env.parse("int", &HashMap::new()); let fun = env.unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(), + FunSignature { args: vec![], ret: int, vars: HashMap::new() }, )); let bar = env.unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] .iter() .cloned() - .collect::>() - .into(), - params: HashMap::new().into(), + .collect::>(), + params: HashMap::new(), }); - let v0 = env.unifier.get_fresh_var().0; - let v1 = env.unifier.get_fresh_var().0; + let v0 = env.unifier.get_dummy_var().0; + let v1 = env.unifier.get_dummy_var().0; let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); - let c = env.unifier.add_record([("f".into(), (v1, false))].iter().cloned().collect()); + let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect()); env.unifier.unify(a, b).unwrap(); env.unifier.unify(b, c).unwrap(); assert!(env.unifier.eq(v1, fun)); - let d = env.unifier.add_record([("a".into(), (v1, true))].iter().cloned().collect()); - assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); + let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); + assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field does not exist".to_string())); - let d = env.unifier.add_record([("b".into(), (v1, true))].iter().cloned().collect()); - assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); + let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); + assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field does not exist".to_string())); } #[test] @@ -409,107 +404,107 @@ fn test_typevar_range() { // unification between v and int // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; env.unifier.unify(int, v).unwrap(); // unification between v and list[int] // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; assert_eq!( - env.unifier.unify(int_list, v), - Err("Cannot unify variable 3 with list[0] due to incompatible value range".to_string()) + env.unify(int_list, v), + Err("Expected any one of these types: 0, 2, but got list[0]".to_string()) ); // unification between v and float // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; assert_eq!( - env.unifier.unify(float, v), - Err("Cannot unify variable 4 with 1 due to incompatible value range".to_string()) + env.unify(float, v), + Err("Expected any one of these types: 0, 2, but got 1".to_string()) ); - let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; // unification between v and int // where v in (int, list[v1]), v1 in (int, bool) env.unifier.unify(int, v).unwrap(); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; // unification between v and list[int] // where v in (int, list[v1]), v1 in (int, bool) env.unifier.unify(int_list, v).unwrap(); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; // unification between v and list[float] // where v in (int, list[v1]), v1 in (int, bool) assert_eq!( - env.unifier.unify(float_list, v), - Err("Cannot unify variable 8 with list[1] due to incompatible value range".to_string()) + env.unify(float_list, v), + Err("Expected any one of these types: 0, list[var5], but got list[1]\n\nNotes:\n var5 ∈ {0, 2}".to_string()) ); - let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, float).unwrap(); - let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; env.unifier.unify(a, b).unwrap(); assert_eq!( - env.unifier.unify(a, int), - Err("Cannot unify variable 12 with 0 due to incompatible value range".into()) + env.unify(a, int), + Err("Expected any one of these types: 1, but got 0".into()) ); - let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); - let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); - let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0; + let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0; env.unifier.unify(a_list, b_list).unwrap(); let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); env.unifier.unify(a_list, float_list).unwrap(); // previous unifications should not affect a and b env.unifier.unify(a, int).unwrap(); - let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int }); assert_eq!( - env.unifier.unify(a_list, int_list), - Err("Cannot unify variable 19 with 0 due to incompatible value range".into()) + env.unify(a_list, int_list), + Err("Expected any one of these types: 1, but got 0".into()) ); - let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; - let b = env.unifier.get_fresh_var().0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; + let b = env.unifier.get_dummy_var().0; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); - let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); assert_eq!( - env.unifier.unify(b, boolean), - Err("Cannot unify variable 21 with 2 due to incompatible value range".into()) + env.unify(b, boolean), + Err("Expected any one of these types: 0, 1, but got 2".into()) ); } #[test] fn test_rigid_var() { let mut env = TestEnvironment::new(); - let a = env.unifier.get_fresh_rigid_var().0; - let b = env.unifier.get_fresh_rigid_var().0; - let x = env.unifier.get_fresh_var().0; + let a = env.unifier.get_fresh_rigid_var(None, None).0; + let b = env.unifier.get_fresh_rigid_var(None, None).0; + let x = env.unifier.get_dummy_var().0; let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); let int = env.parse("int", &HashMap::new()); let list_int = env.parse("list[int]", &HashMap::new()); - assert_eq!(env.unifier.unify(a, b), Err("Cannot unify var3 with var2".to_string())); + assert_eq!(env.unify(a, b), Err("Incompatible types: var3 and var2".to_string())); env.unifier.unify(list_a, list_x).unwrap(); - assert_eq!(env.unifier.unify(list_x, list_int), Err("Cannot unify 0 with var2".to_string())); + assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: 0 and var2".to_string())); env.unifier.replace_rigid_var(a, int); env.unifier.unify(list_x, list_int).unwrap(); @@ -526,13 +521,13 @@ fn test_instantiation() { let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); - let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v }); - let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0; - let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0; - let t = env.unifier.get_fresh_rigid_var().0; + let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0; + let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0; + let t = env.unifier.get_dummy_var().0; let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); - let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0; + let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0; // t = TypeVar('t') // v = TypeVar('v', int, bool) // v1 = TypeVar('v1', 'list[v]', int) @@ -561,9 +556,9 @@ fn test_instantiation() { let types = types .iter() .map(|ty| { - env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| { + env.unifier.internal_stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| { format!("v{}", i) - }) + }, &mut None) }) .sorted() .collect_vec(); diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index b126e09..27df8db 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -46,6 +46,17 @@ impl UnificationTable { } } + pub fn probe_value_immutable(&self, key: UnificationKey) -> &V { + let mut root = key.0; + let mut parent = self.parents[root]; + while root != parent { + root = parent; + // parent = root.parent + parent = self.parents[parent]; + } + self.values[parent].as_ref().unwrap() + } + pub fn probe_value(&mut self, a: UnificationKey) -> &V { let index = self.find(a); self.values[index].as_ref().unwrap() diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 0b079c8..3d1fbaa 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -63,8 +63,8 @@ impl SymbolResolver for Resolver { unimplemented!() } - fn get_identifier_def(&self, id: StrRef) -> Option { - self.0.id_to_def.lock().get(&id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Result { + self.0.id_to_def.lock().get(&id).cloned().ok_or_else(|| "Undefined identifier".to_string()) } fn get_string_id(&self, s: &str) -> i32 { diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index c6dcf56..3d9dada 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -86,7 +86,7 @@ fn main() { get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty) }) .collect::, _>>()?; - Ok(unifier.get_fresh_var_with_range(&constraints).0) + Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) } else { Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) } @@ -219,7 +219,7 @@ fn main() { let mut instance = defs[resolver .get_identifier_def("run".into()) - .unwrap_or_else(|| panic!("cannot find run() entry point")).0 + .unwrap_or_else(|_| panic!("cannot find run() entry point")).0 ].write(); if let TopLevelDef::Function { instance_to_stmt, -- 2.44.1 From f97f93d92cf5d0f8a2508ab58b75069954e6b00a Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 21 Feb 2022 18:27:46 +0800 Subject: [PATCH 3/5] applied rustfmt and clippy auto fix --- nac3artiq/src/codegen.rs | 43 +- nac3artiq/src/lib.rs | 258 ++---- nac3artiq/src/symbol_resolver.rs | 322 +++---- nac3artiq/src/timeline.rs | 182 ++-- nac3core/src/codegen/concrete_type.rs | 53 +- nac3core/src/codegen/expr.rs | 155 +++- nac3core/src/codegen/generator.rs | 35 +- nac3core/src/codegen/irrt/mod.rs | 3 +- nac3core/src/codegen/mod.rs | 91 +- nac3core/src/codegen/stmt.rs | 109 ++- nac3core/src/codegen/test.rs | 15 +- nac3core/src/symbol_resolver.rs | 33 +- nac3core/src/toplevel/builtins.rs | 786 ++++++++++-------- nac3core/src/toplevel/composer.rs | 475 ++++++----- nac3core/src/toplevel/helper.rs | 110 ++- nac3core/src/toplevel/test.rs | 24 +- nac3core/src/toplevel/type_annotation.rs | 161 ++-- nac3core/src/typecheck/magic_methods.rs | 95 +-- nac3core/src/typecheck/mod.rs | 2 +- nac3core/src/typecheck/type_error.rs | 49 +- nac3core/src/typecheck/type_inferencer/mod.rs | 121 ++- .../src/typecheck/type_inferencer/test.rs | 169 ++-- nac3core/src/typecheck/typedef/mod.rs | 343 +++++--- nac3core/src/typecheck/typedef/test.rs | 41 +- nac3standalone/demo/demo.rs | 12 +- nac3standalone/src/basic_symbol_resolver.rs | 6 +- nac3standalone/src/main.rs | 112 +-- 27 files changed, 2038 insertions(+), 1767 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 6aca71d..9e94956 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -130,18 +130,21 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { // the LLVM Context. // The name is guaranteed to be unique as users cannot use this as variable // name. - self.start = old_start.clone().map_or_else(|| { - let start = format!("with-{}-start", self.name_counter).into(); - let start_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, - custom: Some(ctx.primitives.int64), - }; - let start = self.gen_store_target(ctx, &start_expr)?; - ctx.builder.build_store(start, now); - Ok(Some(start_expr)) as Result<_, String> - }, |v| Ok(Some(v)))?; + self.start = old_start.clone().map_or_else( + || { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, + custom: Some(ctx.primitives.int64), + }; + let start = self.gen_store_target(ctx, &start_expr)?; + ctx.builder.build_store(start, now); + Ok(Some(start_expr)) as Result<_, String> + }, + |v| Ok(Some(v)), + )?; let end = format!("with-{}-end", self.name_counter).into(); let end_expr = Located { // location does not matter at this point @@ -179,8 +182,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } // inside a parallel block, should update the outer max now_mu if let Some(old_end) = &old_end { - let outer_end_val = - self.gen_expr(ctx, old_end)?.unwrap().to_basic_value_enum(ctx, self); + let outer_end_val = self + .gen_expr(ctx, old_end)? + .unwrap() + .to_basic_value_enum(ctx, self); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); @@ -226,7 +231,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } } -fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec) -> Result<(), String> { +fn gen_rpc_tag<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: Type, + buffer: &mut Vec, +) -> Result<(), String> { use nac3core::typecheck::typedef::TypeEnum::*; let int32 = ctx.primitives.int32; @@ -283,7 +292,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( let int32 = ctx.ctx.i32_type(); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - let service_id = int32.const_int(fun.1.0 as u64, false); + let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { @@ -433,7 +442,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); - return Ok(None) + return Ok(None); } let prehead_bb = ctx.builder.get_insert_block().unwrap(); diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 5707551..d13fc96 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -10,7 +10,7 @@ use inkwell::{ targets::*, OptimizationLevel, }; -use nac3core::typecheck::typedef::{Unifier, TypeEnum}; +use nac3core::typecheck::typedef::{TypeEnum, Unifier}; use nac3parser::{ ast::{self, ExprKind, Stmt, StmtKind, StrRef}, parser::{self, parse_program}, @@ -21,8 +21,8 @@ use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; use parking_lot::{Mutex, RwLock}; use nac3core::{ - codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, codegen::irrt::load_irrt, + codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{ composer::{ComposerConfig, TopLevelComposer}, @@ -96,10 +96,7 @@ impl Nac3 { ) -> PyResult<()> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let module: &PyAny = module.extract(py)?; - Ok(( - module.getattr("__name__")?.extract()?, - module.getattr("__file__")?.extract()?, - )) + Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) })?; let source = fs::read_to_string(&source_file).map_err(|e| { @@ -111,10 +108,7 @@ impl Nac3 { for mut stmt in parser_result.into_iter() { let include = match stmt.node { ast::StmtKind::ClassDef { - ref decorator_list, - ref mut body, - ref mut bases, - .. + ref decorator_list, ref mut body, ref mut bases, .. } => { let nac3_class = decorator_list.iter().any(|decorator| { if let ast::ExprKind::Name { id, .. } = decorator.node { @@ -146,10 +140,7 @@ impl Nac3 { .unwrap() }); body.retain(|stmt| { - if let ast::StmtKind::FunctionDef { - ref decorator_list, .. - } = stmt.node - { + if let ast::StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { decorator_list.iter().any(|decorator| { if let ast::ExprKind::Name { id, .. } = decorator.node { id.to_string() == "kernel" @@ -165,22 +156,21 @@ impl Nac3 { }); true } - ast::StmtKind::FunctionDef { - ref decorator_list, .. - } => decorator_list.iter().any(|decorator| { - if let ast::ExprKind::Name { id, .. } = decorator.node { - let id = id.to_string(); - id == "extern" || id == "portable" || id == "kernel" || id == "rpc" - } else { - false - } - }), + ast::StmtKind::FunctionDef { ref decorator_list, .. } => { + decorator_list.iter().any(|decorator| { + if let ast::ExprKind::Name { id, .. } = decorator.node { + let id = id.to_string(); + id == "extern" || id == "portable" || id == "kernel" || id == "rpc" + } else { + false + } + }) + } _ => false, }; if include { - self.top_levels - .push((stmt, module_name.clone(), module.clone())); + self.top_levels.push((stmt, module_name.clone(), module.clone())); } } Ok(()) @@ -197,7 +187,7 @@ impl Nac3 { let base_ty = match resolver.get_symbol_type(unifier, top_level_defs, primitives, "base".into()) { Ok(ty) => ty, - Err(e) => return Some(format!("type error inside object launching kernel: {}", e)) + Err(e) => return Some(format!("type error inside object launching kernel: {}", e)), }; let fun_ty = if method_name.is_empty() { @@ -205,12 +195,15 @@ impl Nac3 { } else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) { match fields.get(&(*method_name).into()) { Some(t) => t.0, - None => return Some( - format!("object launching kernel does not have method `{}`", method_name) - ) + None => { + return Some(format!( + "object launching kernel does not have method `{}`", + method_name + )) + } } } else { - return Some("cannot launch kernel by calling a non-callable".into()) + return Some("cannot launch kernel by calling a non-callable".into()); }; if let TypeEnum::TFunc(FunSignature { args, .. }) = &*unifier.get_ty(fun_ty) { @@ -219,35 +212,43 @@ impl Nac3 { "launching kernel function with too many arguments (expect {}, found {})", args.len(), arg_names.len(), - )) + )); } for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() { let in_name = match arg_names.get(i) { Some(n) => n, - None if default_value.is_none() => return Some(format!( - "argument `{}` not provided when launching kernel function", name - )), + None if default_value.is_none() => { + return Some(format!( + "argument `{}` not provided when launching kernel function", + name + )) + } _ => break, }; let in_ty = match resolver.get_symbol_type( unifier, top_level_defs, primitives, - in_name.clone().into() + in_name.clone().into(), ) { Ok(t) => t, - Err(e) => return Some(format!( - "type error ({}) at parameter #{} when calling kernel function", e, i - )) + Err(e) => { + return Some(format!( + "type error ({}) at parameter #{} when calling kernel function", + e, i + )) + } }; if let Err(e) = unifier.unify(in_ty, *ty) { return Some(format!( - "type error ({}) at parameter #{} when calling kernel function", e.to_display(unifier).to_string(), i + "type error ({}) at parameter #{} when calling kernel function", + e.to_display(unifier).to_string(), + i )); } } } else { - return Some("cannot launch kernel by calling a non-callable".into()) + return Some("cannot launch kernel by calling a non-callable".into()); } None } @@ -274,11 +275,7 @@ impl Nac3 { let builtins = vec![ ( "now_mu".into(), - FunSignature { - args: vec![], - ret: primitive.int64, - vars: HashMap::new(), - }, + FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() }, Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Ok(Some(time_fns.emit_now_mu(ctx))) }))), @@ -320,10 +317,7 @@ impl Nac3 { ]; let (_, builtins_def, builtins_ty) = TopLevelComposer::new( builtins.clone(), - ComposerConfig { - kernel_ann: Some("Kernel"), - kernel_invariant_ann: "KernelInvariant", - }, + ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ); let builtins_mod = PyModule::import(py, "builtins").unwrap(); @@ -355,46 +349,22 @@ impl Nac3 { .extract() .unwrap(), ), - none: id_fn - .call1((builtins_mod.getattr("None").unwrap(),)) - .unwrap() - .extract() - .unwrap(), + none: id_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap().extract().unwrap(), typevar: id_fn .call1((typing_mod.getattr("TypeVar").unwrap(),)) .unwrap() .extract() .unwrap(), - int: id_fn - .call1((builtins_mod.getattr("int").unwrap(),)) - .unwrap() - .extract() - .unwrap(), - int32: id_fn - .call1((numpy_mod.getattr("int32").unwrap(),)) - .unwrap() - .extract() - .unwrap(), - int64: id_fn - .call1((numpy_mod.getattr("int64").unwrap(),)) - .unwrap() - .extract() - .unwrap(), - bool: id_fn - .call1((builtins_mod.getattr("bool").unwrap(),)) - .unwrap() - .extract() - .unwrap(), + int: id_fn.call1((builtins_mod.getattr("int").unwrap(),)).unwrap().extract().unwrap(), + int32: id_fn.call1((numpy_mod.getattr("int32").unwrap(),)).unwrap().extract().unwrap(), + int64: id_fn.call1((numpy_mod.getattr("int64").unwrap(),)).unwrap().extract().unwrap(), + bool: id_fn.call1((builtins_mod.getattr("bool").unwrap(),)).unwrap().extract().unwrap(), float: id_fn .call1((builtins_mod.getattr("float").unwrap(),)) .unwrap() .extract() .unwrap(), - list: id_fn - .call1((builtins_mod.getattr("list").unwrap(),)) - .unwrap() - .extract() - .unwrap(), + list: id_fn.call1((builtins_mod.getattr("list").unwrap(),)).unwrap().extract().unwrap(), tuple: id_fn .call1((builtins_mod.getattr("tuple").unwrap(),)) .unwrap() @@ -408,11 +378,7 @@ impl Nac3 { }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); - fs::write( - working_directory.path().join("kernel.ld"), - include_bytes!("kernel.ld"), - ) - .unwrap(); + fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap(); Ok(Nac3 { isa, @@ -425,7 +391,7 @@ impl Nac3 { top_levels: Default::default(), pyid_to_def: Default::default(), working_directory, - string_store: Default::default() + string_store: Default::default(), }) } @@ -465,20 +431,17 @@ impl Nac3 { embedding_map: &PyAny, py: Python, ) -> PyResult<()> { - let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig { - kernel_ann: Some("Kernel"), - kernel_invariant_ann: "KernelInvariant" - }); + let (mut composer, _, _) = TopLevelComposer::new( + self.builtins.clone(), + ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, + ); let builtins = PyModule::import(py, "builtins")?; let typings = PyModule::import(py, "typing")?; let id_fn = builtins.getattr("id")?; let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); - let store_fun = embedding_map - .getattr("store_function") - .unwrap() - .to_object(py); + let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py), @@ -486,7 +449,7 @@ impl Nac3 { origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), store_obj, - store_str + store_str, }; let mut module_to_resolver_cache: HashMap = HashMap::new(); @@ -497,10 +460,8 @@ impl Nac3 { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let helper = helper.clone(); - let (name_to_pyid, resolver) = module_to_resolver_cache - .get(&module_id) - .cloned() - .unwrap_or_else(|| { + let (name_to_pyid, resolver) = + module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = py_module.getattr("__dict__").unwrap().cast_as().unwrap(); @@ -535,7 +496,10 @@ impl Nac3 { let (name, def_id, ty) = composer .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .map_err(|e| { - exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure\n----------\n{}", e)) + exceptions::PyRuntimeError::new_err(format!( + "nac3 compilation failure\n----------\n{}", + e + )) })?; match &stmt.node { @@ -583,16 +547,10 @@ impl Nac3 { let synthesized = if method_name.is_empty() { format!("def __modinit__():\n base({})", arg_names.join(", ")) } else { - format!( - "def __modinit__():\n base.{}({})", - method_name, - arg_names.join(", ") - ) + format!("def __modinit__():\n base.{}({})", method_name, arg_names.join(", ")) }; - let mut synthesized = parse_program( - &synthesized, - "__nac3_synthesized_modinit__".to_string().into(), - ).unwrap(); + let mut synthesized = + parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap(); let resolver = Arc::new(Resolver(Arc::new(InnerResolver { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), @@ -610,34 +568,24 @@ impl Nac3 { string_store: self.string_store.clone(), }))) as Arc; let (_, def_id, _) = composer - .register_top_level( - synthesized.pop().unwrap(), - Some(resolver.clone()), - "".into(), - ) + .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into()) .unwrap(); - let signature = FunSignature { - args: vec![], - ret: self.primitive.none, - vars: HashMap::new(), - }; + let signature = + FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() }; let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); - let signature = store.from_signature( - &mut composer.unifier, - &self.primitive, - &signature, - &mut cache, - ); + let signature = + store.from_signature(&mut composer.unifier, &self.primitive, &signature, &mut cache); let signature = store.add_cty(signature); if let Err(e) = composer.start_analysis(true) { // report error of __modinit__ separately if !e.contains("__nac3_synthesized_modinit__") { - return Err(exceptions::PyRuntimeError::new_err( - format!("nac3 compilation failure: \n----------\n{}", e) - )); + return Err(exceptions::PyRuntimeError::new_err(format!( + "nac3 compilation failure: \n----------\n{}", + e + ))); } else { let msg = Self::report_modinit( &arg_names, @@ -645,7 +593,7 @@ impl Nac3 { resolver.clone(), &composer.extract_def_list(), &mut composer.unifier, - &self.primitive + &self.primitive, ); return Err(exceptions::PyRuntimeError::new_err(msg.unwrap())); } @@ -658,9 +606,7 @@ impl Nac3 { for (class_data, id) in rpc_ids.iter() { let mut def = defs[id.0].write(); match &mut *def { - TopLevelDef::Function { - codegen_callback, .. - } => { + TopLevelDef::Function { codegen_callback, .. } => { *codegen_callback = Some(rpc_codegen.clone()); } TopLevelDef::Class { methods, .. } => { @@ -669,9 +615,8 @@ impl Nac3 { if name != method_name { continue; } - if let TopLevelDef::Function { - codegen_callback, .. - } = &mut *defs[id.0].write() + if let TopLevelDef::Function { codegen_callback, .. } = + &mut *defs[id.0].write() { *codegen_callback = Some(rpc_codegen.clone()); store_fun @@ -693,11 +638,8 @@ impl Nac3 { let instance = { let defs = top_level.definitions.read(); let mut definition = defs[def_id.0].write(); - if let TopLevelDef::Function { - instance_to_stmt, - instance_to_symbol, - .. - } = &mut *definition + if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = + &mut *definition { instance_to_symbol.insert("".to_string(), "__modinit__".into()); instance_to_stmt[""].clone() @@ -733,13 +675,7 @@ impl Nac3 { let thread_names: Vec = (0..4).map(|_| "main".to_string()).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| { - Box::new(ArtiqCodeGenerator::new( - s.to_string(), - size_t, - self.time_fns, - )) - }) + .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) .collect(); py.allow_threads(|| { @@ -784,14 +720,10 @@ impl Nac3 { TargetMachine::get_default_triple(), TargetMachine::get_host_cpu_features().to_string(), ), - Isa::RiscV32G => ( - TargetTriple::create("riscv32-unknown-linux"), - "+a,+m,+f,+d".to_string(), - ), - Isa::RiscV32IMA => ( - TargetTriple::create("riscv32-unknown-linux"), - "+a,+m".to_string(), - ), + Isa::RiscV32G => { + (TargetTriple::create("riscv32-unknown-linux"), "+a,+m,+f,+d".to_string()) + } + Isa::RiscV32IMA => (TargetTriple::create("riscv32-unknown-linux"), "+a,+m".to_string()), Isa::CortexA9 => ( TargetTriple::create("armv7-unknown-linux-gnueabihf"), "+dsp,+fp16,+neon,+vfp3".to_string(), @@ -819,28 +751,18 @@ impl Nac3 { "-x".to_string(), "-o".to_string(), filename.to_string(), - working_directory - .join("module.o") - .to_string_lossy() - .to_string(), + working_directory.join("module.o").to_string_lossy().to_string(), ]; if isa != Isa::Host { linker_args.push( "-T".to_string() - + self - .working_directory - .path() - .join("kernel.ld") - .to_str() - .unwrap(), + + self.working_directory.path().join("kernel.ld").to_str().unwrap(), ); } if let Ok(linker_status) = Command::new("ld.lld").args(linker_args).status() { if !linker_status.success() { - return Err(exceptions::PyRuntimeError::new_err( - "failed to start linker", - )); + return Err(exceptions::PyRuntimeError::new_err("failed to start linker")); } } else { return Err(exceptions::PyRuntimeError::new_err( diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index db513f8..ad32413 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -83,17 +83,21 @@ impl StaticValue for PythonValue { Python::with_gil(|py| -> PyResult> { let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); - let global = - ctx.module - .add_global(struct_type, None, format!("{}_const", self.id).as_str()); + let global = ctx.module.add_global( + struct_type, + None, + format!("{}_const", self.id).as_str(), + ); global.set_constant(true); global.set_initializer(&ctx.ctx.const_struct( &[ctx.ctx.i32_type().const_int(id as u64, false).into()], false, )); - let global2 = - ctx.module - .add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str()); + let global2 = ctx.module.add_global( + struct_type.ptr_type(AddressSpace::Generic), + None, + format!("{}_const2", self.id).as_str(), + ); global2.set_initializer(&global.as_pointer_value()); Ok(global2.as_pointer_value().into()) }) @@ -160,10 +164,7 @@ impl StaticValue for PythonValue { let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?; Some((id, obj)) }; - self.resolver - .field_to_val - .write() - .insert((self.id, name), result.clone()); + self.resolver.field_to_val.write().insert((self.id, name), result.clone()); Ok(result) }) .unwrap() @@ -191,24 +192,27 @@ impl InnerResolver { ) -> PyResult> { let mut ty = match self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)? { Ok(t) => t, - Err(e) => return Ok(Err(format!( - "type error ({}) at element #0 of the list", e - ))), + Err(e) => return Ok(Err(format!("type error ({}) at element #0 of the list", e))), }; for i in 1..len { let b = match list .get_item(i) - .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))?? { - Ok(t) => t, - Err(e) => return Ok(Err(format!( - "type error ({}) at element #{} of the list", e, i - ))), - }; + .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))?? + { + Ok(t) => t, + Err(e) => { + return Ok(Err(format!("type error ({}) at element #{} of the list", e, i))) + } + }; ty = match unifier.unify(ty, b) { Ok(_) => ty, - Err(e) => return Ok(Err(format!( - "inhomogeneous type ({}) at element #{} of the list", e.to_display(unifier).to_string(), i - ))) + Err(e) => { + return Ok(Err(format!( + "inhomogeneous type ({}) at element #{} of the list", + e.to_display(unifier).to_string(), + i + ))) + } }; } Ok(Ok(ty)) @@ -227,11 +231,8 @@ impl InnerResolver { primitives: &PrimitiveStore, ) -> PyResult> { let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; - let ty_ty_id: u64 = self - .helper - .id_fn - .call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))? - .extract(py)?; + let ty_ty_id: u64 = + self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Ok((primitives.int32, true))) @@ -243,7 +244,7 @@ impl InnerResolver { Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.exception { Ok(Ok((primitives.exception, true))) - }else if ty_id == self.primitive_ids.list { + } else if ty_id == self.primitive_ids.list { // do not handle type var param and concrete check here let var = unifier.get_dummy_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); @@ -253,28 +254,21 @@ impl InnerResolver { Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { let def = defs[def_id.0].read(); - if let TopLevelDef::Class { - object_id, - type_vars, - fields, - methods, - .. - } = &*def - { + if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { // do not handle type var param and concrete check here, and no subst Ok(Ok({ let ty = TypeEnum::TObj { obj_id: *object_id, params: type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { - (*id, *x) - } else { - unreachable!() - } - }) - .collect(), + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + (*id, *x) + } else { + unreachable!() + } + }) + .collect(), fields: { let mut res = methods .iter() @@ -320,7 +314,8 @@ impl InnerResolver { } result }; - let res = unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; + let res = + unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; Ok(Ok((res, true))) } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 @@ -352,7 +347,7 @@ impl InnerResolver { }; if !unifier.is_concrete(ty.0, &[]) && !ty.1 { return Ok(Err( - "type list should take concrete parameters in typevar range".into() + "type list should take concrete parameters in typevar range".into(), )); } Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) @@ -417,10 +412,7 @@ impl InnerResolver { .map(|((id, _), ty)| (*id, *ty)) .collect::>() }; - Ok(Ok(( - unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), - true, - ))) + Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) } TypeEnum::TVirtual { .. } => { if args.len() == 1 { @@ -452,17 +444,19 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.virtual_id { Ok(Ok(( { - let ty = TypeEnum::TVirtual { - ty: unifier.get_dummy_var().0, - }; + let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().0 }; unifier.add_ty(ty) }, false, ))) } else { - let str_fn = pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); + let str_fn = + pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); - Ok(Err(format!("{} is not supported in nac3 (did you forgot to put @nac3 annotation?)", str_repr))) + Ok(Err(format!( + "{} is not supported in nac3 (did you forgot to put @nac3 annotation?)", + str_repr + ))) } } @@ -483,13 +477,8 @@ impl InnerResolver { self.primitive_ids.generic_alias.0, self.primitive_ids.generic_alias.1, ] - .contains( - &self - .helper - .id_fn - .call1(py, (ty.clone(),))? - .extract::(py)?, - ) { + .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) + { obj } else { ty.as_ref(py) @@ -518,9 +507,12 @@ impl InnerResolver { self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; match actual_ty { Ok(t) => match unifier.unify(*ty, t) { - Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList{ ty: *ty }))), - Err(e) => Ok(Err(format!("type error ({}) for the list", e.to_display(unifier).to_string()))), - } + Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList { ty: *ty }))), + Err(e) => Ok(Err(format!( + "type error ({}) for the list", + e.to_display(unifier).to_string() + ))), + }, Err(e) => Ok(Err(e)), } } @@ -553,18 +545,23 @@ impl InnerResolver { continue; } else { let field_data = obj.getattr(&name)?; - let ty = match self - .get_obj_type(py, field_data, unifier, defs, primitives)? { + let ty = + match self.get_obj_type(py, field_data, unifier, defs, primitives)? { Ok(t) => t, - Err(e) => return Ok(Err(format!( - "error when getting type of field `{}` ({})", name, e - ))), + Err(e) => { + return Ok(Err(format!( + "error when getting type of field `{}` ({})", + name, e + ))) + } }; let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); if let Err(e) = unifier.unify(ty, field_ty) { // field type mismatch return Ok(Err(format!( - "error when getting type of field `{}` ({})", name, e.to_display(unifier).to_string() + "error when getting type of field `{}` ({})", + name, + e.to_display(unifier).to_string() ))); } } @@ -575,11 +572,7 @@ impl InnerResolver { return Ok(Err("object is not of concrete type".into())); } } - return Ok(Ok( - unifier - .subst(extracted_ty, &var_map) - .unwrap_or(extracted_ty), - )); + return Ok(Ok(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); } _ => Ok(Ok(extracted_ty)), }; @@ -592,37 +585,24 @@ impl InnerResolver { ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut dyn CodeGenerator, ) -> PyResult>> { - let ty_id: u64 = self - .helper - .id_fn - .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? - .extract(py)?; + let ty_id: u64 = + self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { let val: i32 = obj.extract()?; - self.id_to_primitive - .write() - .insert(id, PrimitiveValue::I32(val)); + self.id_to_primitive.write().insert(id, PrimitiveValue::I32(val)); Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) } else if ty_id == self.primitive_ids.int64 { let val: i64 = obj.extract()?; - self.id_to_primitive - .write() - .insert(id, PrimitiveValue::I64(val)); + self.id_to_primitive.write().insert(id, PrimitiveValue::I64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; - self.id_to_primitive - .write() - .insert(id, PrimitiveValue::Bool(val)); - Ok(Some( - ctx.ctx.bool_type().const_int(val as u64, false).into(), - )) + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.bool_type().const_int(val as u64, false).into())) } else if ty_id == self.primitive_ids.float { let val: f64 = obj.extract()?; - self.id_to_primitive - .write() - .insert(id, PrimitiveValue::F64(val)); + self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); Ok(Some(ctx.ctx.f64_type().const_float(val).into())) } else if ty_id == self.primitive_ids.list { let id_str = id.to_string(); @@ -647,16 +627,14 @@ impl InnerResolver { }; let ty = ctx.get_llvm_type(generator, ty); let size_t = generator.get_size_type(ctx.ctx); - let arr_ty = ctx.ctx.struct_type( - &[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], - false, - ); + let arr_ty = ctx + .ctx + .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false); { if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module - .add_global(arr_ty, Some(AddressSpace::Generic), &id_str) + ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { @@ -666,8 +644,7 @@ impl InnerResolver { let arr: Result>, _> = (0..len) .map(|i| { - obj.get_item(i) - .and_then(|elem| self.get_obj_value(py, elem, ctx, generator)) + obj.get_item(i).and_then(|elem| self.get_obj_value(py, elem, ctx, generator)) }) .collect(); let arr = arr?.unwrap(); @@ -678,34 +655,19 @@ impl InnerResolver { &(id_str.clone() + "_"), ); let arr: BasicValueEnum = if ty.is_int_type() { - let arr: Vec<_> = arr - .into_iter() - .map(BasicValueEnum::into_int_value) - .collect(); + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_int_value).collect(); ty.into_int_type().const_array(&arr) } else if ty.is_float_type() { - let arr: Vec<_> = arr - .into_iter() - .map(BasicValueEnum::into_float_value) - .collect(); + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_float_value).collect(); ty.into_float_type().const_array(&arr) } else if ty.is_array_type() { - let arr: Vec<_> = arr - .into_iter() - .map(BasicValueEnum::into_array_value) - .collect(); + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_array_value).collect(); ty.into_array_type().const_array(&arr) } else if ty.is_struct_type() { - let arr: Vec<_> = arr - .into_iter() - .map(BasicValueEnum::into_struct_value) - .collect(); + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_struct_value).collect(); ty.into_struct_type().const_array(&arr) } else if ty.is_pointer_type() { - let arr: Vec<_> = arr - .into_iter() - .map(BasicValueEnum::into_pointer_value) - .collect(); + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_pointer_value).collect(); ty.into_pointer_type().const_array(&arr) } else { unreachable!() @@ -714,16 +676,11 @@ impl InnerResolver { arr_global.set_initializer(&arr); let val = arr_ty.const_named_struct(&[ - arr_global - .as_pointer_value() - .const_cast(ty.ptr_type(AddressSpace::Generic)) - .into(), + arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(), size_t.const_int(len as u64, false).into(), ]); - let global = ctx - .module - .add_global(arr_ty, Some(AddressSpace::Generic), &id_str); + let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) @@ -754,8 +711,7 @@ impl InnerResolver { { if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module - .add_global(ty, Some(AddressSpace::Generic), &id_str) + ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { @@ -763,15 +719,11 @@ impl InnerResolver { } } - let val: Result>, _> = elements - .iter() - .map(|elem| self.get_obj_value(py, elem, ctx, generator)) - .collect(); + let val: Result>, _> = + elements.iter().map(|elem| self.get_obj_value(py, elem, ctx, generator)).collect(); let val = val?.unwrap(); let val = ctx.ctx.const_struct(&val, false); - let global = ctx - .module - .add_global(ty, Some(AddressSpace::Generic), &id_str); + let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { @@ -793,8 +745,7 @@ impl InnerResolver { { if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module - .add_global(ty, Some(AddressSpace::Generic), &id_str) + ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { @@ -802,10 +753,8 @@ impl InnerResolver { } } // should be classes - let definition = top_level_defs - .get(self.pyid_to_def.read().get(&ty_id).unwrap().0) - .unwrap() - .read(); + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read(); if let TopLevelDef::Class { fields, .. } = &*definition { let values: Result>, _> = fields .iter() @@ -816,9 +765,7 @@ impl InnerResolver { let values = values?; if let Some(values) = values { let val = ty.const_named_struct(&values); - let global = ctx - .module - .add_global(ty, Some(AddressSpace::Generic), &id_str); + let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { @@ -835,39 +782,32 @@ impl InnerResolver { py: Python, obj: &PyAny, ) -> PyResult> { - let ty_id: u64 = self - .helper - .id_fn - .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? - .extract(py)?; - Ok( - if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - let val: i32 = obj.extract()?; - Ok(SymbolValue::I32(val)) - } else if ty_id == self.primitive_ids.int64 { - let val: i64 = obj.extract()?; - Ok(SymbolValue::I64(val)) - } else if ty_id == self.primitive_ids.bool { - let val: bool = obj.extract()?; - Ok(SymbolValue::Bool(val)) - } else if ty_id == self.primitive_ids.float { - let val: f64 = obj.extract()?; - Ok(SymbolValue::Double(val)) - } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let elements: Result, String>, _> = elements - .iter() - .map(|elem| self.get_default_param_obj_value(py, elem)) - .collect(); - let elements = match elements? { - Ok(el) => el, - Err(err) => return Ok(Err(err)), - }; - Ok(SymbolValue::Tuple(elements)) - } else { - Err("only primitives values and tuple can be default parameter value".into()) - }, - ) + let ty_id: u64 = + self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; + Ok(if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(SymbolValue::I32(val)) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(SymbolValue::I64(val)) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.float { + let val: f64 = obj.extract()?; + Ok(SymbolValue::Double(val)) + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.cast_as()?; + let elements: Result, String>, _> = + elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect(); + let elements = match elements? { + Ok(el) => el, + Err(err) => return Ok(Err(err)), + }; + Ok(SymbolValue::Tuple(elements)) + } else { + Err("only primitives values and tuple can be default parameter value".into()) + }) } } @@ -882,12 +822,8 @@ impl SymbolResolver for Resolver { for (key, val) in members.iter() { let key: &str = key.extract()?; if key == id.to_string() { - sym_value = Some( - self.0 - .get_default_param_obj_value(py, val) - .unwrap() - .unwrap(), - ); + sym_value = + Some(self.0.get_default_param_obj_value(py, val).unwrap().unwrap()); break; } } @@ -992,7 +928,8 @@ impl SymbolResolver for Resolver { id_to_def.get(&id).cloned().ok_or_else(|| "".to_string()) } .or_else(|_| { - let py_id = self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?; + let py_id = + self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?; let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or(format!( "`{}` is not registered in nac3, did you forgot to add @nac3?", id @@ -1008,8 +945,9 @@ impl SymbolResolver for Resolver { *id } else { let id = Python::with_gil(|py| -> PyResult { - self.0.helper.store_str.call1(py, (s, ))?.extract(py) - }).unwrap(); + self.0.helper.store_str.call1(py, (s,))?.extract(py) + }) + .unwrap(); string_store.insert(s.into(), id); id } diff --git a/nac3artiq/src/timeline.rs b/nac3artiq/src/timeline.rs index b9c3ddc..f66fc58 100644 --- a/nac3artiq/src/timeline.rs +++ b/nac3artiq/src/timeline.rs @@ -1,5 +1,5 @@ -use nac3core::codegen::CodeGenContext; use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}; +use nac3core::codegen::CodeGenContext; pub trait TimeFns { fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx>; @@ -19,41 +19,23 @@ impl TimeFns for NowPinningTimeFns64 { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder.build_bitcast( - now, - i32_type.ptr_type(AddressSpace::Generic), - "now_hiptr" - ); + let now_hiptr = + ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr"); if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { let now_loptr = unsafe { - ctx.builder.build_gep( - now_hiptr, - &[i32_type.const_int(2, false)], - "now_gep", - ) + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep") }; - if let ( - BasicValueEnum::IntValue(now_hi), - BasicValueEnum::IntValue(now_lo) - ) = ( + if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( ctx.builder.build_load(now_hiptr, "now_hi"), - ctx.builder.build_load(now_loptr, "now_lo") + ctx.builder.build_load(now_loptr, "now_lo"), ) { - let zext_hi = ctx.builder.build_int_z_extend( - now_hi, - i64_type, - "now_zext_hi" - ); + let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi"); let shifted_hi = ctx.builder.build_left_shift( zext_hi, i64_type.const_int(32, false), - "now_shifted_zext_hi" - ); - let zext_lo = ctx.builder.build_int_z_extend( - now_lo, - i64_type, - "now_zext_lo" + "now_shifted_zext_hi", ); + let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo"); ctx.builder.build_or(shifted_hi, zext_lo, "now_or").into() } else { unreachable!(); @@ -69,8 +51,7 @@ impl TimeFns for NowPinningTimeFns64 { let i64_32 = i64_type.const_int(32, false); if let BasicValueEnum::IntValue(time) = t { let time_hi = ctx.builder.build_int_truncate( - ctx.builder - .build_right_shift(time, i64_32, false, "now_lshr"), + ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), i32_type, "now_trunc", ); @@ -86,11 +67,7 @@ impl TimeFns for NowPinningTimeFns64 { ); if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { let now_loptr = unsafe { - ctx.builder.build_gep( - now_hiptr, - &[i32_type.const_int(2, false)], - "now_gep", - ) + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep") }; ctx.builder .build_store(now_hiptr, time_hi) @@ -108,66 +85,54 @@ impl TimeFns for NowPinningTimeFns64 { } } - fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { + fn emit_delay_mu<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + dt: BasicValueEnum<'ctx>, + ) { let i64_type = ctx.ctx.i64_type(); let i32_type = ctx.ctx.i32_type(); let now = ctx .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder.build_bitcast( - now, - i32_type.ptr_type(AddressSpace::Generic), - "now_hiptr" - ); + let now_hiptr = + ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr"); if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { let now_loptr = unsafe { - ctx.builder.build_gep( - now_hiptr, - &[i32_type.const_int(2, false)], - "now_loptr", - ) + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_loptr") }; if let ( BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo), - BasicValueEnum::IntValue(dt) + BasicValueEnum::IntValue(dt), ) = ( ctx.builder.build_load(now_hiptr, "now_hi"), ctx.builder.build_load(now_loptr, "now_lo"), - dt + dt, ) { - let zext_hi = ctx.builder.build_int_z_extend( - now_hi, - i64_type, - "now_zext_hi" - ); + let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi"); let shifted_hi = ctx.builder.build_left_shift( zext_hi, i64_type.const_int(32, false), - "now_shifted_zext_hi" - ); - let zext_lo = ctx.builder.build_int_z_extend( - now_lo, - i64_type, - "now_zext_lo" + "now_shifted_zext_hi", ); + let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo"); let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now_or"); - + let time = ctx.builder.build_int_add(now_val, dt, "now_add"); let time_hi = ctx.builder.build_int_truncate( - ctx.builder - .build_right_shift( - time, - i64_type.const_int(32, false), - false, - "now_lshr" - ), + ctx.builder.build_right_shift( + time, + i64_type.const_int(32, false), + false, + "now_lshr", + ), i32_type, "now_trunc", ); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); - + ctx.builder .build_store(now_hiptr, time_hi) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) @@ -200,9 +165,7 @@ impl TimeFns for NowPinningTimeFns { if let BasicValueEnum::IntValue(now_raw) = now_raw { let i64_32 = i64_type.const_int(32, false); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); - let now_hi = ctx - .builder - .build_right_shift(now_raw, i64_32, false, "now_lshr"); + let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr"); ctx.builder.build_or(now_lo, now_hi, "now_or").into() } else { unreachable!(); @@ -215,8 +178,7 @@ impl TimeFns for NowPinningTimeFns { let i64_32 = i64_type.const_int(32, false); if let BasicValueEnum::IntValue(time) = t { let time_hi = ctx.builder.build_int_truncate( - ctx.builder - .build_right_shift(time, i64_32, false, "now_lshr"), + ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), i32_type, "now_trunc", ); @@ -232,11 +194,7 @@ impl TimeFns for NowPinningTimeFns { ); if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { let now_loptr = unsafe { - ctx.builder.build_gep( - now_hiptr, - &[i32_type.const_int(1, false)], - "now_gep", - ) + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep") }; ctx.builder .build_store(now_hiptr, time_hi) @@ -254,7 +212,11 @@ impl TimeFns for NowPinningTimeFns { } } - fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { + fn emit_delay_mu<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + dt: BasicValueEnum<'ctx>, + ) { let i32_type = ctx.ctx.i32_type(); let i64_type = ctx.ctx.i64_type(); let i64_32 = i64_type.const_int(32, false); @@ -263,18 +225,13 @@ impl TimeFns for NowPinningTimeFns { .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); - if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = - (now_raw, dt) - { + if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) { let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); - let now_hi = ctx - .builder - .build_right_shift(now_raw, i64_32, false, "now_lshr"); + let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr"); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or"); let time = ctx.builder.build_int_add(now_val, dt, "now_add"); let time_hi = ctx.builder.build_int_truncate( - ctx.builder - .build_right_shift(time, i64_32, false, "now_lshr"), + ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), i32_type, "now_trunc", ); @@ -286,11 +243,7 @@ impl TimeFns for NowPinningTimeFns { ); if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { let now_loptr = unsafe { - ctx.builder.build_gep( - now_hiptr, - &[i32_type.const_int(1, false)], - "now_gep", - ) + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep") }; ctx.builder .build_store(now_hiptr, time_hi) @@ -315,33 +268,36 @@ pub struct ExternTimeFns {} impl TimeFns for ExternTimeFns { fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> { - let now_mu = ctx - .module - .get_function("now_mu") - .unwrap_or_else(|| ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)); - ctx.builder - .build_call(now_mu, &[], "now_mu") - .try_as_basic_value() - .left() - .unwrap() + let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { + ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) + }); + ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap() } fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) { - let at_mu = ctx - .module - .get_function("at_mu") - .unwrap_or_else(|| ctx.module.add_function("at_mu", ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), None)); - ctx.builder - .build_call(at_mu, &[t.into()], "at_mu"); + let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| { + ctx.module.add_function( + "at_mu", + ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), + None, + ) + }); + ctx.builder.build_call(at_mu, &[t.into()], "at_mu"); } - fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>) { - let delay_mu = ctx - .module - .get_function("delay_mu") - .unwrap_or_else(|| ctx.module.add_function("delay_mu", ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), None)); - ctx.builder - .build_call(delay_mu, &[dt.into()], "delay_mu"); + fn emit_delay_mu<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + dt: BasicValueEnum<'ctx>, + ) { + let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { + ctx.module.add_function( + "delay_mu", + ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), + None, + ) + }); + ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu"); } } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 36acaaa..bd7573c 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -33,7 +33,7 @@ pub enum Primitive { None, Range, Str, - Exception + Exception, } #[derive(Debug)] @@ -162,10 +162,16 @@ impl ConcreteTypeStore { // here we should not have type vars, but some partial instantiated // class methods can still have uninstantiated type vars, so // filter out all the methods, as this will not affect codegen - if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) { + if let TypeEnum::TFunc(..) = &*unifier.get_ty(ty.0) { None } else { - Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))) + Some(( + *name, + ( + self.from_unifier_type(unifier, primitives, ty.0, cache), + ty.1, + ), + )) } }) .collect(), @@ -246,34 +252,27 @@ impl ConcreteTypeStore { .map(|(name, cty)| { (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) }) - .collect::>() - .into(), + .collect::>(), params: params .iter() .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) - .collect::>() - .into(), + .collect::>(), }, - ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc( - FunSignature { - args: args - .iter() - .map(|arg| FuncArg { - name: arg.name, - ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), - default_value: arg.default_value.clone(), - }) - .collect(), - ret: self.to_unifier_type(unifier, primitives, *ret, cache), - vars: vars - .iter() - .map(|(id, cty)| { - (*id, self.to_unifier_type(unifier, primitives, *cty, cache)) - }) - .collect::>(), - } - .into(), - ), + ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { + args: args + .iter() + .map(|arg| FuncArg { + name: arg.name, + ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), + default_value: arg.default_value.clone(), + }) + .collect(), + ret: self.to_unifier_type(unifier, primitives, *ret, cache), + vars: vars + .iter() + .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) + .collect::>(), + }), }; let result = unifier.add_ty(result); if let Some(ty) = cache.get(&cty).unwrap() { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 11b8554..24467da 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,9 +3,9 @@ use std::{collections::HashMap, convert::TryInto, iter::once}; use crate::{ codegen::{ concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, - stmt::gen_raise, get_llvm_type, irrt::*, + stmt::gen_raise, CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, @@ -13,12 +13,14 @@ use crate::{ typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; use inkwell::{ - AddressSpace, types::{BasicType, BasicTypeEnum}, - values::{BasicValueEnum, FunctionValue, IntValue, PointerValue} + values::{BasicValueEnum, FunctionValue, IntValue, PointerValue}, + AddressSpace, }; use itertools::{chain, izip, zip, Itertools}; -use nac3parser::ast::{self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef}; +use nac3parser::ast::{ + self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, +}; use super::CodeGenerator; @@ -40,7 +42,14 @@ pub fn get_subst_key( vars.extend(fun_vars.iter()); let sorted = vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); sorted - .map(|id| unifier.internal_stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None)) + .map(|id| { + unifier.internal_stringify( + vars[id], + &mut |id| id.to_string(), + &mut |id| id.to_string(), + &mut None, + ) + }) .join(", ") } @@ -77,14 +86,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { index } - pub fn gen_symbol_val(&mut self, generator: &mut dyn CodeGenerator, val: &SymbolValue) -> BasicValueEnum<'ctx> { + pub fn gen_symbol_val( + &mut self, + generator: &mut dyn CodeGenerator, + val: &SymbolValue, + ) -> BasicValueEnum<'ctx> { match val { SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(), SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(), SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Str(v) => { - let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); + let str_ptr = + self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); ty.const_named_struct(&[str_ptr, size.into()]).into() @@ -125,7 +139,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ) } - pub fn gen_const(&mut self, generator: &mut dyn CodeGenerator, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { + pub fn gen_const( + &mut self, + generator: &mut dyn CodeGenerator, + value: &Constant, + ty: Type, + ) -> BasicValueEnum<'ctx> { match value { Constant::Bool(v) => { assert!(self.unifier.unioned(ty, self.primitives.bool)); @@ -163,10 +182,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { if let Some(v) = self.const_strings.get(v) { *v } else { - let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); + let str_ptr = + self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str); - let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); + let val = + ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); self.const_strings.insert(v.to_string(), val); val } @@ -262,12 +283,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &self, fun: FunctionValue<'ctx>, params: &[BasicValueEnum<'ctx>], - call_name: &str + call_name: &str, ) -> Option> { if let Some(target) = self.unwind_target { let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name)); - let result = self.builder.build_invoke(fun, params, then_block, target, call_name).try_as_basic_value().left(); + let result = self + .builder + .build_invoke(fun, params, then_block, target, call_name) + .try_as_basic_value() + .left(); self.builder.position_at_end(then_block); result } else { @@ -279,7 +304,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { pub fn gen_string>( &mut self, generator: &mut G, - s: S + s: S, ) -> BasicValueEnum<'ctx> { self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) } @@ -290,7 +315,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { name: &str, msg: BasicValueEnum<'ctx>, params: [Option>; 3], - loc: Location + loc: Location, ) { let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); @@ -302,13 +327,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let id = self.resolver.get_string_id(name); self.builder.build_store(id_ptr, int32.const_int(id as u64, false)); let ptr = self.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(5, false)], "exn.msg"); + zelf, + &[zero, int32.const_int(5, false)], + "exn.msg", + ); self.builder.build_store(ptr, msg); let i64_zero = self.ctx.i64_type().const_zero(); for (i, attr_ind) in [6, 7, 8].iter().enumerate() { let ptr = self.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(*attr_ind, false)], "exn.param"); - let val = params[i].map_or(i64_zero, |v| self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext")); + zelf, + &[zero, int32.const_int(*attr_ind, false)], + "exn.param", + ); + let val = params[i].map_or(i64_zero, |v| { + self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext") + }); self.builder.build_store(ptr, val); } } @@ -322,19 +355,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { err_name: &str, err_msg: &str, params: [Option>; 3], - loc: Location + loc: Location, ) { let i1 = self.ctx.bool_type(); let i1_true = i1.const_all_ones(); let expect_fun = self.module.get_function("llvm.expect.i1").unwrap_or_else(|| { - self.module.add_function("llvm.expect", i1.fn_type(&[i1.into(), i1.into()], false), None) + self.module.add_function( + "llvm.expect", + i1.fn_type(&[i1.into(), i1.into()], false), + None, + ) }); // we assume that the condition is most probably true, so the normal path is the most // probable path // even if this assumption is violated, it does not matter as exception unwinding is // slow anyway... - let cond = self.builder.build_call(expect_fun, &[cond.into(), i1_true.into()], "expect") - .try_as_basic_value().left().unwrap().into_int_value(); + let cond = self + .builder + .build_call(expect_fun, &[cond.into(), i1_true.into()], "expect") + .try_as_basic_value() + .left() + .unwrap() + .into_int_value(); let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_block = self.ctx.append_basic_block(current_fun, "succ"); let exn_block = self.ctx.append_basic_block(current_fun, "fail"); @@ -400,7 +442,7 @@ pub fn gen_func_instance<'ctx, 'a>( return Ok(sym.clone()); } let symbol = format!("{}.{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key.clone(), symbol.clone()); + instance_to_symbol.insert(key, symbol.clone()); let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id)); let instance = instance_to_stmt.get(&key).unwrap(); @@ -484,7 +526,10 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( } // default value handling for k in keys.into_iter() { - mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into()); + mapping.insert( + k.name, + ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into(), + ); } // reorder the parameters let mut real_params = @@ -821,7 +866,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( // we should use memcpy for that instead of generating thousands of stores let elements = elts .iter() - .map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) + .map(|x| { + generator + .gen_expr(ctx, x) + .map(|v| v.unwrap().to_basic_value_enum(ctx, generator)) + }) .collect::, _>>()?; let ty = if elements.is_empty() { if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) { @@ -850,7 +899,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ExprKind::Tuple { elts, .. } => { let element_val = elts .iter() - .map(|x| generator.gen_expr(ctx, x).map(|v| v.unwrap().to_basic_value_enum(ctx, generator))) + .map(|x| { + generator + .gen_expr(ctx, x) + .map(|v| v.unwrap().to_basic_value_enum(ctx, generator)) + }) .collect::, _>>()?; let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let tuple_ty = ctx.ctx.struct_type(&element_ty, false); @@ -935,7 +988,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); - let val = generator.gen_expr(ctx, operand)?.unwrap().to_basic_value_enum(ctx, generator); + let val = + generator.gen_expr(ctx, operand)?.unwrap().to_basic_value_enum(ctx, generator); if ty == ctx.primitives.bool { let val = val.into_int_value(); match op { @@ -1074,8 +1128,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.as_basic_value().into() } ExprKind::Call { func, args, keywords } => { - let mut params = - args.iter().map(|arg| Ok((None, generator.gen_expr(ctx, arg)?.unwrap())) as Result<_, String>) + let mut params = args + .iter() + .map(|arg| Ok((None, generator.gen_expr(ctx, arg)?.unwrap())) as Result<_, String>) .collect::, _>>()?; let kw_iter = keywords.iter().map(|kw| { Ok(( @@ -1101,7 +1156,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( match &func.node { ExprKind::Name { id, .. } => { // TODO: handle primitive casts and function pointers - let fun = ctx.resolver.get_identifier_def(*id).map_err(|e| format!("{} (at {})", e, func.location))?; + let fun = ctx + .resolver + .get_identifier_def(*id) + .map_err(|e| format!("{} (at {})", e, func.location))?; return Ok(generator .gen_call(ctx, None, (&signature, fun), params)? .map(|v| v.into())); @@ -1187,24 +1245,47 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ); res_array_ret.into() } else { - let len = ctx.build_gep_and_load(v, &[zero, int32.const_int(1, false)]) + let len = ctx + .build_gep_and_load(v, &[zero, int32.const_int(1, false)]) .into_int_value(); let raw_index = generator .gen_expr(ctx, slice)? .unwrap() .to_basic_value_enum(ctx, generator) .into_int_value(); - let raw_index = ctx.builder.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext"); + let raw_index = ctx.builder.build_int_s_extend( + raw_index, + generator.get_size_type(ctx.ctx), + "sext", + ); // handle negative index - let is_negative = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, raw_index, - generator.get_size_type(ctx.ctx).const_zero(), "is_neg"); + let is_negative = ctx.builder.build_int_compare( + inkwell::IntPredicate::SLT, + raw_index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ); let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); - let index = ctx.builder.build_select(is_negative, adjusted, raw_index, "index").into_int_value(); + let index = ctx + .builder + .build_select(is_negative, adjusted, raw_index, "index") + .into_int_value(); // unsigned less than is enough, because negative index after adjustment is // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder.build_int_compare(inkwell::IntPredicate::ULT, index, len, "inbound"); - ctx.make_assert(generator, bound_check, "0:IndexError", "index {0} out of bounds 0:{1}", - [Some(raw_index), Some(len), None], expr.location); + let bound_check = ctx.builder.build_int_compare( + inkwell::IntPredicate::ULT, + index, + len, + "inbound", + ); + ctx.make_assert( + generator, + bound_check, + "0:IndexError", + "index {0} out of bounds 0:{1}", + [Some(raw_index), Some(len), None], + expr.location, + ); ctx.build_gep_and_load(arr_ptr, &[index]) } } else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 25d2a27..4fcbfdb 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -118,8 +118,11 @@ pub trait CodeGenerator { /// Generate code for a while expression. /// Return true if the while loop must early return - fn gen_while<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) - -> Result<(), String> + fn gen_while<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> Result<(), String> where Self: Sized, { @@ -128,8 +131,11 @@ pub trait CodeGenerator { /// Generate code for a while expression. /// Return true if the while loop must early return - fn gen_for<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) - -> Result<(), String> + fn gen_for<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> Result<(), String> where Self: Sized, { @@ -138,16 +144,22 @@ pub trait CodeGenerator { /// Generate code for an if expression. /// Return true if the statement must early return - fn gen_if<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) - -> Result<(), String> + fn gen_if<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> Result<(), String> where Self: Sized, { gen_if(self, ctx, stmt) } - fn gen_with<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) - -> Result<(), String> + fn gen_with<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> Result<(), String> where Self: Sized, { @@ -156,8 +168,11 @@ pub trait CodeGenerator { /// Generate code for a statement /// Return true if the statement must early return - fn gen_stmt<'ctx, 'a>(&mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>) - -> Result<(), String> + fn gen_stmt<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> Result<(), String> where Self: Sized, { diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index cc03bcf..bced56d 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -215,7 +215,8 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>( }); let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator); - Ok(ctx.builder + Ok(ctx + .builder .build_call(func, &[i.into(), length.into()], "bounded_ind") .try_as_basic_value() .left() diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index c34f0fb..d83b22b 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -30,8 +30,8 @@ use std::thread; pub mod concrete_type; pub mod expr; mod generator; -pub mod stmt; pub mod irrt; +pub mod stmt; #[cfg(test)] mod test; @@ -274,12 +274,22 @@ fn get_llvm_type<'ctx>( // a struct with fields in the order of declaration let top_level_defs = top_level.definitions.read(); let definition = top_level_defs.get(obj_id.0).unwrap(); - let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } = &*definition.read() + let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } = + &*definition.read() { let struct_type = ctx.opaque_struct_type(&name.to_string()); let fields = fields_list .iter() - .map(|f| get_llvm_type(ctx, generator, unifier, top_level, type_cache, fields[&f.0].0)) + .map(|f| { + get_llvm_type( + ctx, + generator, + unifier, + top_level, + type_cache, + fields[&f.0].0, + ) + }) .collect_vec(); struct_type.set_body(&fields, false); struct_type.ptr_type(AddressSpace::Generic).into() @@ -298,9 +308,12 @@ fn get_llvm_type<'ctx>( } TList { ty } => { // a struct with an integer and a pointer to an array - let element_type = get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty); - let fields = - [element_type.ptr_type(AddressSpace::Generic).into(), generator.get_size_type(ctx).into()]; + let element_type = + get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty); + let fields = [ + element_type.ptr_type(AddressSpace::Generic).into(), + generator.get_size_type(ctx).into(), + ]; ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() } TVirtual { .. } => unimplemented!(), @@ -331,14 +344,17 @@ pub fn gen_func<'ctx, G: CodeGenerator>( // this should be unification between variables and concrete types // and should not cause any problem... let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache); - unifier.unify(*a, b).or_else(|err| { - if matches!(&*unifier.get_ty(*a), TypeEnum::TRigidVar { .. }) { - unifier.replace_rigid_var(*a, b); - Ok(()) - } else { - Err(err) - } - }).unwrap() + unifier + .unify(*a, b) + .or_else(|err| { + if matches!(&*unifier.get_ty(*a), TypeEnum::TRigidVar { .. }) { + unifier.replace_rigid_var(*a, b); + Ok(()) + } else { + Err(err) + } + }) + .unwrap() } // rebuild primitive store with unique representatives @@ -367,10 +383,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( str_type.set_body(&fields, false); str_type.into() }), - ( - primitives.range, - context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into(), - ), + (primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into()), ] .iter() .cloned() @@ -380,17 +393,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let int32 = context.i32_type().into(); let int64 = context.i64_type().into(); let str_ty = *type_cache.get(&primitives.str).unwrap(); - let fields = [ - int32, - str_ty, - int32, - int32, - str_ty, - str_ty, - int64, - int64, - int64 - ]; + let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64]; exception.set_body(&fields, false); exception.ptr_type(AddressSpace::Generic).into() }); @@ -414,15 +417,30 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let params = args .iter() .map(|arg| { - get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty).into() + get_llvm_type( + context, + generator, + &mut unifier, + top_level_ctx.as_ref(), + &mut type_cache, + arg.ty, + ) + .into() }) .collect_vec(); let fn_type = if unifier.unioned(ret, primitives.none) { context.void_type().fn_type(¶ms, false) } else { - get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret) - .fn_type(¶ms, false) + get_llvm_type( + context, + generator, + &mut unifier, + top_level_ctx.as_ref(), + &mut type_cache, + ret, + ) + .fn_type(¶ms, false) }; let symbol = &task.symbol_name; @@ -445,7 +463,14 @@ pub fn gen_func<'ctx, G: CodeGenerator>( for (n, arg) in args.iter().enumerate() { let param = fn_val.get_nth_param(n as u32).unwrap(); let alloca = builder.build_alloca( - get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty), + get_llvm_type( + context, + generator, + &mut unifier, + top_level_ctx.as_ref(), + &mut type_cache, + arg.ty, + ), &arg.name.to_string(), ); builder.build_store(alloca, param); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 247cd29..7793d54 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -7,7 +7,7 @@ use super::{ use crate::{ codegen::expr::gen_binop_expr, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{Type, TypeEnum, FunSignature} + typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -16,7 +16,9 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, IntPredicate::EQ, }; -use nac3parser::ast::{ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant}; +use nac3parser::ast::{ + Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, +}; use std::convert::TryFrom; pub fn gen_var<'ctx, 'a>( @@ -40,12 +42,16 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples Ok(match &pattern.node { - ExprKind::Name { id, .. } => ctx.var_assignment.get(id).map(|v| Ok(v.0) as Result<_, String>).unwrap_or_else(|| { - let ptr_ty = ctx.get_llvm_type(generator, pattern.custom.unwrap()); - let ptr = generator.gen_var_alloc(ctx, ptr_ty)?; - ctx.var_assignment.insert(*id, (ptr, None, 0)); - Ok(ptr) - })?, + ExprKind::Name { id, .. } => { + ctx.var_assignment.get(id).map(|v| Ok(v.0) as Result<_, String>).unwrap_or_else( + || { + let ptr_ty = ctx.get_llvm_type(generator, pattern.custom.unwrap()); + let ptr = generator.gen_var_alloc(ctx, ptr_ty)?; + ctx.var_assignment.insert(*id, (ptr, None, 0)); + Ok(ptr) + }, + )? + } ExprKind::Attribute { value, attr, .. } => { let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let val = generator.gen_expr(ctx, value)?.unwrap().to_basic_value_enum(ctx, generator); @@ -94,7 +100,7 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( target: &Expr>, value: ValueEnum<'ctx>, ) -> Result<(), String> { - Ok(match &target.node { + match &target.node { ExprKind::Tuple { elts, .. } => { if let BasicValueEnum::StructValue(v) = value.to_basic_value_enum(ctx, generator) { for (i, elt) in elts.iter().enumerate() { @@ -120,13 +126,12 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( let (start, end, step) = handle_slice_indices(lower, upper, step, ctx, generator, ls)?; let value = value.to_basic_value_enum(ctx, generator).into_pointer_value(); - let ty = if let TypeEnum::TList { ty } = - &*ctx.unifier.get_ty(target.custom.unwrap()) - { - ctx.get_llvm_type(generator, *ty) - } else { - unreachable!() - }; + let ty = + if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) { + ctx.get_llvm_type(generator, *ty) + } else { + unreachable!() + }; let src_ind = handle_slice_indices(&None, &None, &None, ctx, generator, value)?; list_slice_assignment( ctx, @@ -153,7 +158,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( let val = value.to_basic_value_enum(ctx, generator); ctx.builder.build_store(ptr, val); } - }) + }; + Ok(()) } pub fn gen_for<'ctx, 'a, G: CodeGenerator>( @@ -420,10 +426,10 @@ pub fn get_builtins<'ctx, 'a, G: CodeGenerator>( ) -> FunctionValue<'ctx> { ctx.module.get_function(symbol).unwrap_or_else(|| { let ty = match symbol { - "__artiq_raise" => ctx.ctx.void_type().fn_type( - &[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], - false, - ), + "__artiq_raise" => ctx + .ctx + .void_type() + .fn_type(&[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], false), "__artiq_resume" => ctx.ctx.void_type().fn_type(&[], false), "__artiq_end_catch" => ctx.ctx.void_type().fn_type(&[], false), _ => unimplemented!(), @@ -444,7 +450,7 @@ pub fn exn_constructor<'ctx, 'a>( obj: Option<(Type, ValueEnum<'ctx>)>, _fun: (&FunSignature, DefinitionId), mut args: Vec<(Option, ValueEnum<'ctx>)>, - generator: &mut dyn CodeGenerator + generator: &mut dyn CodeGenerator, ) -> Result>, String> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator).into_pointer_value(); @@ -459,19 +465,16 @@ pub fn exn_constructor<'ctx, 'a>( }; let defs = ctx.top_level.definitions.read(); let def = defs[zelf_id].read(); - let zelf_name = if let TopLevelDef::Class { name, .. } = &*def { - *name - } else { - unreachable!() - }; + let zelf_name = + if let TopLevelDef::Class { name, .. } = &*def { *name } else { unreachable!() }; let exception_name = format!("0:{}", zelf_name); unsafe { let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); let id = ctx.resolver.get_string_id(&exception_name); ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)); let empty_string = ctx.gen_const(generator, &Constant::Str("".into()), ctx.primitives.str); - let ptr = ctx.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(5, false)], "exn.msg"); + let ptr = + ctx.builder.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg"); let msg = if !args.is_empty() { args.remove(0).1.to_basic_value_enum(ctx, generator) } else { @@ -485,19 +488,28 @@ pub fn exn_constructor<'ctx, 'a>( ctx.ctx.i64_type().const_zero().into() }; let ptr = ctx.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(*i, false)], "exn.param"); + zelf, + &[zero, int32.const_int(*i, false)], + "exn.param", + ); ctx.builder.build_store(ptr, value); } // set file, func to empty string for i in [1, 4].iter() { let ptr = ctx.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(*i, false)], "exn.str"); + zelf, + &[zero, int32.const_int(*i, false)], + "exn.str", + ); ctx.builder.build_store(ptr, empty_string); } // set ints to zero for i in [2, 3].iter() { let ptr = ctx.builder.build_in_bounds_gep( - zelf, &[zero, int32.const_int(*i, false)], "exn.ints"); + zelf, + &[zero, int32.const_int(*i, false)], + "exn.ints", + ); ctx.builder.build_store(ptr, zero); } } @@ -515,17 +527,33 @@ pub fn gen_raise<'ctx, 'a, G: CodeGenerator>( let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let exception = exception.into_pointer_value(); - let file_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr"); + let file_ptr = ctx.builder.build_in_bounds_gep( + exception, + &[zero, int32.const_int(1, false)], + "file_ptr", + ); let filename = ctx.gen_string(generator, loc.file.0); ctx.builder.build_store(file_ptr, filename); - let row_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr"); + let row_ptr = ctx.builder.build_in_bounds_gep( + exception, + &[zero, int32.const_int(2, false)], + "row_ptr", + ); ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)); - let col_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr"); + let col_ptr = ctx.builder.build_in_bounds_gep( + exception, + &[zero, int32.const_int(3, false)], + "col_ptr", + ); ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)); let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); - let name_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr"); + let name_ptr = ctx.builder.build_in_bounds_gep( + exception, + &[zero, int32.const_int(4, false)], + "name_ptr", + ); ctx.builder.build_store(name_ptr, fun_name); } @@ -599,7 +627,11 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( for handler_node in handlers.iter() { let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; // none or Exception - if type_.is_none() || ctx.unifier.unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) { + if type_.is_none() + || ctx + .unifier + .unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) + { clauses.push(None); found_catch_all = true; break; @@ -928,7 +960,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { if let Some(exc) = exc { - let exc = generator.gen_expr(ctx, exc)?.unwrap().to_basic_value_enum(ctx, generator); + let exc = + generator.gen_expr(ctx, exc)?.unwrap().to_basic_value_enum(ctx, generator); gen_raise(generator, ctx, Some(&exc), stmt.location); } else { gen_raise(generator, ctx, None, stmt.location); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ffcf330..bb96918 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -34,7 +34,10 @@ impl Resolver { } impl SymbolResolver for Resolver { - fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + fn get_default_param_value( + &self, + _: &nac3parser::ast::Expr, + ) -> Option { unimplemented!() } @@ -57,7 +60,11 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Result { - self.id_to_def.read().get(&id).cloned().ok_or_else(|| format!("cannot find symbol `{}`", id)) + self.id_to_def + .read() + .get(&id) + .cloned() + .ok_or_else(|| format!("cannot find symbol `{}`", id)) } fn get_string_id(&self, _: &str) -> i32 { @@ -118,7 +125,7 @@ fn test_primitives() { virtual_checks: &mut virtual_checks, calls: &mut calls, defined_identifiers: identifiers.clone(), - in_handler: false + in_handler: false, }; inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32); @@ -263,7 +270,7 @@ fn test_simple_call() { virtual_checks: &mut virtual_checks, calls: &mut calls, defined_identifiers: identifiers.clone(), - in_handler: false + in_handler: false, }; inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("foo".into(), fun_ty); diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index aed4843..d88200a 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,7 +1,8 @@ -use std::{collections::HashMap, fmt::Display}; use std::fmt::Debug; use std::sync::Arc; +use std::{collections::HashMap, fmt::Display}; +use crate::typecheck::typedef::TypeEnum; use crate::{ codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}, @@ -13,7 +14,6 @@ use crate::{ typedef::{Type, Unifier}, }, }; -use crate::typecheck::typedef::TypeEnum; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use itertools::{chain, izip}; use nac3parser::ast::{Expr, Location, StrRef}; @@ -36,11 +36,13 @@ impl Display for SymbolValue { SymbolValue::I64(i) => write!(f, "int64({})", i), SymbolValue::Str(s) => write!(f, "\"{}\"", s), SymbolValue::Double(d) => write!(f, "{}", d), - SymbolValue::Bool(b) => if *b { - write!(f, "True") - } else { - write!(f, "False") - }, + SymbolValue::Bool(b) => { + if *b { + write!(f, "True") + } else { + write!(f, "False") + } + } SymbolValue::Tuple(t) => { write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::>().join(", ")) } @@ -203,7 +205,8 @@ pub fn parse_type_annotation( let fields = chain( fields.iter().map(|(k, v, m)| (*k, (*v, *m))), methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ).collect(); + ) + .collect(); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, @@ -214,7 +217,8 @@ pub fn parse_type_annotation( } } Err(e) => { - let ty = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) + let ty = resolver + .get_symbol_type(unifier, top_level_defs, primitives, *id) .map_err(|_| format!("Unknown type annotation at {}: {}", loc, e))?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) @@ -256,8 +260,7 @@ pub fn parse_type_annotation( vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?] }; - let obj_id = resolver - .get_identifier_def(*id)?; + let obj_id = resolver.get_identifier_def(*id)?; let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if types.len() != type_vars.len() { @@ -287,11 +290,7 @@ pub fn parse_type_annotation( let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); (*attr, (ty, false)) })); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: subst, - })) + Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) } else { Err("Cannot use function name as type".into()) } @@ -338,7 +337,7 @@ impl dyn SymbolResolver + Send + Sync { } }, &mut |id| format!("var{}", id), - &mut None + &mut None, ) } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index fbf3b2a..a48dd17 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,14 +1,13 @@ use super::*; use crate::{ - codegen::{expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor}, + codegen::{ + expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor, + }, symbol_resolver::SymbolValue, }; use inkwell::{FloatPredicate, IntPredicate}; -type BuiltinInfo = ( - Vec<(Arc>, Option)>, - &'static [&'static str] -); +type BuiltinInfo = (Vec<(Arc>, Option)>, &'static [&'static str]); pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let int32 = primitives.0.int32; @@ -17,7 +16,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; - let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean], Some("N".into()), None); + let num_ty = primitives.1.get_fresh_var_with_range( + &[int32, int64, float, boolean], + Some("N".into()), + None, + ); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let exception_fields = vec![ @@ -34,65 +37,83 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let div_by_zero = primitives.1.add_ty(TypeEnum::TObj { obj_id: DefinitionId(10), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), - params: Default::default() + params: Default::default(), }); let index_error = primitives.1.add_ty(TypeEnum::TObj { obj_id: DefinitionId(11), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), - params: Default::default() + params: Default::default(), }); let exn_cons_args = vec![ - FuncArg { name: "msg".into(), ty: string, - default_value: Some(SymbolValue::Str("".into()))}, - FuncArg { name: "param0".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, - FuncArg { name: "param1".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, - FuncArg { name: "param2".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + FuncArg { + name: "msg".into(), + ty: string, + default_value: Some(SymbolValue::Str("".into())), + }, + FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) }, + FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) }, + FuncArg { name: "param2".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) }, ]; let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: exn_cons_args.clone(), ret: div_by_zero, - vars: Default::default() + vars: Default::default(), })); let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: exn_cons_args, ret: index_error, - vars: Default::default() + vars: Default::default(), })); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 0, - None, - "int32".into(), - None, - None, + 0, + None, + "int32".into(), + None, + None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 1, - None, - "int64".into(), - None, - None, + 1, + None, + "int64".into(), + None, + None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 2, - None, - "float".into(), - None, - None, + 2, + None, + "float".into(), + None, + None, ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(3, None, "bool".into(), None, None))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(4, None, "none".into(), None, None))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - 5, - None, - "range".into(), - None, - None, + 3, + None, + "bool".into(), + None, + None, + ))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( + 4, + None, + "none".into(), + None, + None, + ))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( + 5, + None, + "range".into(), + None, + None, + ))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( + 6, + None, + "str".into(), + None, + None, ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(6, None, "str".into(), None, None))), Arc::new(RwLock::new(TopLevelDef::Class { name: "Exception".into(), object_id: DefinitionId(7), @@ -134,7 +155,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { methods: vec![("__init__".into(), div_by_zero_signature, DefinitionId(8))], ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(10), params: Default::default() }, - TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } + TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }, ], constructor: Some(div_by_zero_signature), resolver: None, @@ -148,7 +169,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { methods: vec![("__init__".into(), index_error_signature, DefinitionId(9))], ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(11), params: Default::default() }, - TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } + TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }, ], constructor: Some(index_error_signature), resolver: None, @@ -167,49 +188,49 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let boolean = ctx.primitives.bool; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - Ok(if ctx.unifier.unioned(arg_ty, boolean) { - Some( - ctx.builder - .build_int_z_extend( - arg.into_int_value(), - ctx.ctx.i32_type(), - "zext", - ) - .into(), - ) - } else if ctx.unifier.unioned(arg_ty, int32) { - Some(arg) - } else if ctx.unifier.unioned(arg_ty, int64) { - Some( - ctx.builder - .build_int_truncate( - arg.into_int_value(), - ctx.ctx.i32_type(), - "trunc", - ) - .into(), - ) - } else if ctx.unifier.unioned(arg_ty, float) { - let val = ctx - .builder - .build_float_to_signed_int( - arg.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(); - Some(val) - } else { - unreachable!() - }) - }, + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok(if ctx.unifier.unioned(arg_ty, boolean) { + Some( + ctx.builder + .build_int_z_extend( + arg.into_int_value(), + ctx.ctx.i32_type(), + "zext", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, int32) { + Some(arg) + } else if ctx.unifier.unioned(arg_ty, int64) { + Some( + ctx.builder + .build_int_truncate( + arg.into_int_value(), + ctx.ctx.i32_type(), + "trunc", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, float) { + let val = ctx + .builder + .build_float_to_signed_int( + arg.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(); + Some(val) + } else { + unreachable!() + }) + }, )))), loc: None, })), @@ -226,41 +247,43 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let boolean = ctx.primitives.bool; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - Ok(if ctx.unifier.unioned(arg_ty, boolean) - || ctx.unifier.unioned(arg_ty, int32) - { - Some( - ctx.builder - .build_int_z_extend( - arg.into_int_value(), - ctx.ctx.i64_type(), - "zext", - ) - .into(), + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok( + if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) + { + Some( + ctx.builder + .build_int_z_extend( + arg.into_int_value(), + ctx.ctx.i64_type(), + "zext", ) - } else if ctx.unifier.unioned(arg_ty, int64) { - Some(arg) - } else if ctx.unifier.unioned(arg_ty, float) { - let val = ctx - .builder - .build_float_to_signed_int( - arg.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(); - Some(val) - } else { - unreachable!() - }) - }, + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, int64) { + Some(arg) + } else if ctx.unifier.unioned(arg_ty, float) { + let val = ctx + .builder + .build_float_to_signed_int( + arg.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(); + Some(val) + } else { + unreachable!() + }, + ) + }, )))), loc: None, })), @@ -277,29 +300,31 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let boolean = ctx.primitives.bool; - let float = ctx.primitives.float; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - Ok(if ctx.unifier.unioned(arg_ty, boolean) - || ctx.unifier.unioned(arg_ty, int32) - || ctx.unifier.unioned(arg_ty, int64) - { - let arg = arg.into_int_value(); - let val = ctx - .builder - .build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp") - .into(); - Some(val) - } else if ctx.unifier.unioned(arg_ty, float) { - Some(arg) - } else { - unreachable!() - }) - }, + |ctx, _, fun, args, generator| { + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let boolean = ctx.primitives.bool; + let float = ctx.primitives.float; + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + Ok( + if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) + || ctx.unifier.unioned(arg_ty, int64) + { + let arg = arg.into_int_value(); + let val = ctx + .builder + .build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp") + .into(); + Some(val) + } else if ctx.unifier.unioned(arg_ty, float) { + Some(arg) + } else { + unreachable!() + }, + ) + }, )))), loc: None, })), @@ -315,30 +340,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let round_intrinsic = - ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.round.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(round_intrinsic, &[arg.into()], "round") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let round_intrinsic = + ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(round_intrinsic, &[arg.into()], "round") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -353,30 +380,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let round_intrinsic = - ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.round.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(round_intrinsic, &[arg.into()], "round") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let round_intrinsic = + ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(round_intrinsic, &[arg.into()], "round") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -404,55 +433,57 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let mut start = None; - let mut stop = None; - let mut step = None; - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - for (i, arg) in args.iter().enumerate() { - if arg.0 == Some("start".into()) { - start = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); - } else if arg.0 == Some("stop".into()) { - stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); - } else if arg.0 == Some("step".into()) { - step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); - } else if i == 0 { - start = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); - } else if i == 1 { - stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); - } else if i == 2 { - step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let mut start = None; + let mut stop = None; + let mut step = None; + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + for (i, arg) in args.iter().enumerate() { + if arg.0 == Some("start".into()) { + start = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } else if arg.0 == Some("stop".into()) { + stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } else if arg.0 == Some("step".into()) { + step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } else if i == 0 { + start = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } else if i == 1 { + stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } else if i == 2 { + step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)); + } } - } - // TODO: error when step == 0 - let step = step.unwrap_or_else(|| int32.const_int(1, false).into()); - let stop = stop.unwrap_or_else(|| { - let v = start.unwrap(); - start = None; - v - }); - let start = start.unwrap_or_else(|| int32.const_zero().into()); - let ty = int32.array_type(3); - let ptr = ctx.builder.build_alloca(ty, "range"); - unsafe { - let a = ctx.builder.build_in_bounds_gep(ptr, &[zero, zero], "start"); - let b = ctx.builder.build_in_bounds_gep( - ptr, - &[zero, int32.const_int(1, false)], - "end", - ); - let c = ctx.builder.build_in_bounds_gep( - ptr, - &[zero, int32.const_int(2, false)], - "step", - ); - ctx.builder.build_store(a, start); - ctx.builder.build_store(b, stop); - ctx.builder.build_store(c, step); - } - Ok(Some(ptr.into())) - })))), + // TODO: error when step == 0 + let step = step.unwrap_or_else(|| int32.const_int(1, false).into()); + let stop = stop.unwrap_or_else(|| { + let v = start.unwrap(); + start = None; + v + }); + let start = start.unwrap_or_else(|| int32.const_zero().into()); + let ty = int32.array_type(3); + let ptr = ctx.builder.build_alloca(ty, "range"); + unsafe { + let a = ctx.builder.build_in_bounds_gep(ptr, &[zero, zero], "start"); + let b = ctx.builder.build_in_bounds_gep( + ptr, + &[zero, int32.const_int(1, false)], + "end", + ); + let c = ctx.builder.build_in_bounds_gep( + ptr, + &[zero, int32.const_int(2, false)], + "step", + ); + ctx.builder.build_store(a, start); + ctx.builder.build_store(b, stop); + ctx.builder.build_store(c, step); + } + Ok(Some(ptr.into())) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -467,9 +498,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - Ok(Some(args[0].1.clone().to_basic_value_enum(ctx, generator))) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + Ok(Some(args[0].1.clone().to_basic_value_enum(ctx, generator))) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -495,28 +528,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(if ctx.unifier.unioned(arg_ty, boolean) { Some(arg) } else if ctx.unifier.unioned(arg_ty, int32) { - Some(ctx.builder.build_int_compare( - IntPredicate::NE, - ctx.ctx.i32_type().const_zero(), - arg.into_int_value(), - "bool", - ).into()) + Some( + ctx.builder + .build_int_compare( + IntPredicate::NE, + ctx.ctx.i32_type().const_zero(), + arg.into_int_value(), + "bool", + ) + .into(), + ) } else if ctx.unifier.unioned(arg_ty, int64) { - Some(ctx.builder.build_int_compare( - IntPredicate::NE, - ctx.ctx.i64_type().const_zero(), - arg.into_int_value(), - "bool", - ).into()) + Some( + ctx.builder + .build_int_compare( + IntPredicate::NE, + ctx.ctx.i64_type().const_zero(), + arg.into_int_value(), + "bool", + ) + .into(), + ) } else if ctx.unifier.unioned(arg_ty, float) { - let val = ctx.builder. - build_float_compare( + let val = ctx + .builder + .build_float_compare( // UEQ as bool(nan) is True FloatPredicate::UEQ, arg.into_float_value(), ctx.ctx.f64_type().const_zero(), - "bool" - ).into(); + "bool", + ) + .into(); Some(val) } else { unreachable!() @@ -537,30 +580,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let floor_intrinsic = - ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(floor_intrinsic, &[arg.into()], "floor") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let floor_intrinsic = + ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(floor_intrinsic, &[arg.into()], "floor") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -575,30 +620,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let floor_intrinsic = - ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(floor_intrinsic, &[arg.into()], "floor") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let floor_intrinsic = + ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(floor_intrinsic, &[arg.into()], "floor") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -613,30 +660,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let ceil_intrinsic = - ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(ceil_intrinsic, &[arg.into()], "ceil") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i32_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let ceil_intrinsic = + ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(ceil_intrinsic, &[arg.into()], "ceil") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i32_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Function { @@ -651,47 +700,51 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); - let ceil_intrinsic = - ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(ceil_intrinsic, &[arg.into()], "ceil") - .try_as_basic_value() - .left() - .unwrap(); - Ok(Some( - ctx.builder - .build_float_to_signed_int( - val.into_float_value(), - ctx.ctx.i64_type(), - "fptosi", - ) - .into(), - )) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + let ceil_intrinsic = + ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let float = ctx.ctx.f64_type(); + let fn_type = float.fn_type(&[float.into()], false); + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + let val = ctx + .builder + .build_call(ceil_intrinsic, &[arg.into()], "ceil") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some( + ctx.builder + .build_float_to_signed_int( + val.into_float_value(), + ctx.ctx.i64_type(), + "fptosi", + ) + .into(), + )) + }, + )))), loc: None, })), Arc::new(RwLock::new({ let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); - let arg_ty = primitives.1.get_fresh_var_with_range(&[list, primitives.0.range], Some("I".into()), None); + let arg_ty = primitives.1.get_fresh_var_with_range( + &[list, primitives.0.range], + Some("I".into()), + None, + ); TopLevelDef::Function { name: "len".into(), simple_name: "len".into(), signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "ls".into(), - ty: arg_ty.0, - default_value: None - }], + args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], ret: int32, - vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), + vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)] + .into_iter() + .collect(), })), var_id: vec![arg_ty.1], instance_to_symbol: Default::default(), @@ -709,7 +762,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { } else { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); - let len = ctx.build_gep_and_load(arg.into_pointer_value(), &[zero, int32.const_int(1, false)]).into_int_value(); + let len = ctx + .build_gep_and_load( + arg.into_pointer_value(), + &[zero, int32.const_int(1, false)], + ) + .into_int_value(); if len.get_type().get_bit_width() != 32 { Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) } else { @@ -720,7 +778,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, } - })) + })), ]; let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); @@ -742,6 +800,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "ceil", "ceil64", "len", - ] + ], ) } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index e5a2920..4927507 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,9 +1,9 @@ use nac3parser::ast::fold::Fold; use crate::{ - typecheck::type_inferencer::{FunctionData, Inferencer}, codegen::{expr::get_subst_key, stmt::exn_constructor}, symbol_resolver::SymbolValue, + typecheck::type_inferencer::{FunctionData, Inferencer}, }; use super::*; @@ -15,10 +15,7 @@ pub struct ComposerConfig { impl Default for ComposerConfig { fn default() -> Self { - ComposerConfig { - kernel_ann: None, - kernel_invariant_ann: "Invariant" - } + ComposerConfig { kernel_ann: None, kernel_invariant_ann: "Invariant" } } } @@ -52,7 +49,7 @@ impl TopLevelComposer { /// resolver can later figure out primitive type definitions when passed a primitive type name pub fn new( builtins: Vec<(StrRef, FunSignature, Arc)>, - core_config: ComposerConfig + core_config: ComposerConfig, ) -> (Self, HashMap, HashMap) { let mut primitives = Self::make_primitives(); let (mut definition_ast_list, builtin_name_list) = builtins::get_builtins(&mut primitives); @@ -89,7 +86,8 @@ impl TopLevelComposer { assert!(name == *simple_name); builtin_ty.insert(name, *signature); builtin_id.insert(name, DefinitionId(id)); - } else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def { + } else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def + { assert!(id == object_id.0); assert!(type_vars.is_empty()); if let Some(constructor) = constructor { @@ -377,7 +375,7 @@ impl TopLevelComposer { unreachable!("must be both class") } } else { - return Ok(()) + return Ok(()); } }; let class_resolver = class_resolver.as_ref().unwrap(); @@ -484,72 +482,75 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); let primitive_types = self.primitives_ty; - let mut get_direct_parents = |class_def: &Arc>, class_ast: &Option| { - let mut class_def = class_def.write(); - let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { - if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = - class_def.deref_mut() - { - if let Some(ast::Located { - node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast + let mut get_direct_parents = + |class_def: &Arc>, class_ast: &Option| { + let mut class_def = class_def.write(); + let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { + if let TopLevelDef::Class { + ancestors, resolver, object_id, type_vars, .. + } = class_def.deref_mut() { - (object_id, bases, ancestors, resolver, type_vars) + if let Some(ast::Located { + node: ast::StmtKind::ClassDef { bases, .. }, + .. + }) = class_ast + { + (object_id, bases, ancestors, resolver, type_vars) + } else { + unreachable!("must be both class") + } } else { - unreachable!("must be both class") + return Ok(()); } - } else { - return Ok(()); - } - }; - let class_resolver = class_resolver.as_ref().unwrap(); - let class_resolver = class_resolver.deref(); + }; + let class_resolver = class_resolver.as_ref().unwrap(); + let class_resolver = class_resolver.deref(); - let mut has_base = false; - for b in class_bases { - // type vars have already been handled, so skip on `Generic[...]` - if matches!( - &b.node, - ast::ExprKind::Subscript { value, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &"Generic".into() - ) - ) { - continue; - } + let mut has_base = false; + for b in class_bases { + // type vars have already been handled, so skip on `Generic[...]` + if matches!( + &b.node, + ast::ExprKind::Subscript { value, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &"Generic".into() + ) + ) { + continue; + } - if has_base { - return Err(format!( - "a class definition can only have at most one base class \ + if has_base { + return Err(format!( + "a class definition can only have at most one base class \ declaration and one generic declaration (at {})", - b.location - )); - } - has_base = true; + b.location + )); + } + has_base = true; - // the function parse_ast_to make sure that no type var occured in - // bast_ty if it is a CustomClassKind - let base_ty = parse_ast_to_type_annotation_kinds( - class_resolver, - &temp_def_list, - unifier, - &primitive_types, - b, - vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), - )?; + // the function parse_ast_to make sure that no type var occured in + // bast_ty if it is a CustomClassKind + let base_ty = parse_ast_to_type_annotation_kinds( + class_resolver, + &temp_def_list, + unifier, + &primitive_types, + b, + vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), + )?; - if let TypeAnnotation::CustomClass { .. } = &base_ty { - class_ancestors.push(base_ty); - } else { - return Err(format!( - "class base declaration can only be custom class (at {})", - b.location, - )); + if let TypeAnnotation::CustomClass { .. } = &base_ty { + class_ancestors.push(base_ty); + } else { + return Err(format!( + "class base declaration can only be custom class (at {})", + b.location, + )); + } } - } - Ok(()) - }; + Ok(()) + }; // first, only push direct parent into the list let mut errors = HashSet::new(); @@ -570,7 +571,7 @@ impl TopLevelComposer { if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() { (ancestors, *object_id) } else { - return Ok(()) + return Ok(()); } }; ancestors_store.insert( @@ -614,11 +615,17 @@ impl TopLevelComposer { .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); // special case classes that inherit from Exception - if class_ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { + if class_ancestors + .iter() + .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) + { // if inherited from Exception, the body should be a pass if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node { for stmt in body.iter() { - if matches!(stmt.node, ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }) { + if matches!( + stmt.node, + ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } + ) { return Err("Classes inherited from exception should have no custom fields/methods".into()); } } @@ -629,7 +636,9 @@ impl TopLevelComposer { } // deal with ancestor of Exception object - if let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *self.definition_ast_list[7].0.write() { + if let TopLevelDef::Class { name, ancestors, object_id, .. } = + &mut *self.definition_ast_list[7].0.write() + { assert_eq!(*name, "Exception".into()); ancestors.push(make_self_type_annotation(&[], *object_id)); } else { @@ -658,7 +667,7 @@ impl TopLevelComposer { unifier, primitives, &mut type_var_to_concrete_def, - (&self.keyword_list, &self.core_config) + (&self.keyword_list, &self.core_config), ) { errors.insert(e); } @@ -740,7 +749,7 @@ impl TopLevelComposer { x } else { // if let TopLevelDef::Function { name, .. } = `` - return Ok(()) + return Ok(()); }; if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = @@ -760,7 +769,9 @@ impl TopLevelComposer { // make sure no duplicate parameter let mut defined_paramter_name: HashSet<_> = HashSet::new(); for x in args.args.iter() { - if !defined_paramter_name.insert(x.node.arg) || keyword_list.contains(&x.node.arg) { + if !defined_paramter_name.insert(x.node.arg) + || keyword_list.contains(&x.node.arg) + { return Err(format!( "top level function must have unique parameter names \ and names should not be the same as the keywords (at {})", @@ -769,17 +780,21 @@ impl TopLevelComposer { } } - let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + let arg_with_default: Vec<( + &ast::Located>, + Option<&ast::Expr>, + )> = args .args .iter() .rev() - .zip(args - .defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)) - ).collect_vec(); + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); arg_with_default .iter() @@ -839,16 +854,21 @@ impl TopLevelComposer { default_value: match default { None => None, Some(default) => Some({ - let v = Self::parse_parameter_default_value(default, resolver)?; + let v = Self::parse_parameter_default_value( + default, resolver, + )?; Self::check_default_param_type( &v, &type_annotation, primitives_store, - unifier - ).map_err(|err| format!("{} (at {})", err, x.location))?; + unifier, + ) + .map_err( + |err| format!("{} (at {})", err, x.location), + )?; v - }) - } + }), + }, }) }) .collect::, _>>()? @@ -910,18 +930,20 @@ impl TopLevelComposer { .collect_vec() .as_slice() ); - let function_ty = unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } - )); - unifier - .unify(*dummy_ty, function_ty) - .map_err(|e| e.at(Some(function_ast.location)).to_display(unifier).to_string())?; + let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_types, + ret: return_ty, + vars: function_var_map, + })); + unifier.unify(*dummy_ty, function_ty).map_err(|e| { + e.at(Some(function_ast.location)).to_display(unifier).to_string() + })?; } else { unreachable!("must be both function"); } } else { // not top level function def, skip - return Ok(()) + return Ok(()); } Ok(()) }; @@ -931,7 +953,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")) + return Err(errors.iter().join("\n----------\n")); } Ok(()) } @@ -1003,35 +1025,46 @@ impl TopLevelComposer { let zelf: StrRef = "self".into(); for x in args.args.iter() { if !defined_paramter_name.insert(x.node.arg) - || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) { + || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) + { return Err(format!( "top level function must have unique parameter names \ and names should not be the same as the keywords (at {})", x.location - )) + )); } } if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) { - return Err(format!("__init__ method must have a `self` parameter (at {})", b.location)); + return Err(format!( + "__init__ method must have a `self` parameter (at {})", + b.location + )); } if !defined_paramter_name.contains(&zelf) { - return Err(format!("class method must have a `self` parameter (at {})", b.location)); + return Err(format!( + "class method must have a `self` parameter (at {})", + b.location + )); } let mut result = Vec::new(); - let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + let arg_with_default: Vec<( + &ast::Located>, + Option<&ast::Expr>, + )> = args .args .iter() .rev() - .zip(args - .defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)) - ).collect_vec(); + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); for (x, default) in arg_with_default.into_iter().rev() { let name = x.node.arg; @@ -1085,13 +1118,20 @@ impl TopLevelComposer { return Err(format!("`self` parameter cannot take default value (at {})", x.location)); } Some({ - let v = Self::parse_parameter_default_value(default, class_resolver)?; - Self::check_default_param_type(&v, &type_ann, primitives, unifier) - .map_err(|err| format!("{} (at {})", err, x.location))?; + let v = Self::parse_parameter_default_value( + default, + class_resolver, + )?; + Self::check_default_param_type( + &v, &type_ann, primitives, unifier, + ) + .map_err(|err| { + format!("{} (at {})", err, x.location) + })?; v }) } - } + }, }; // push the dummy type and the type annotation // into the list for later unification @@ -1162,14 +1202,17 @@ impl TopLevelComposer { } else { unreachable!() } - let method_type = unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: arg_types, ret: ret_type, vars: method_var_map } - .into(), - )); + let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_types, + ret: ret_type, + vars: method_var_map, + })); // unify now since function type is not in type annotation define // which should be fine since type within method_type will be subst later - unifier.unify(method_dummy_ty, method_type).map_err(|e| e.to_display(unifier).to_string())?; + unifier + .unify(method_dummy_ty, method_type) + .map_err(|e| e.to_display(unifier).to_string())?; } ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { if let ast::ExprKind::Name { id: attr, .. } = &target.node { @@ -1178,16 +1221,24 @@ impl TopLevelComposer { // handle Kernel[T], KernelInvariant[T] let (annotation, mutable) = match &annotation.node { - ast::ExprKind::Subscript { value, slice, .. } if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() - ) => (slice, false), - ast::ExprKind::Subscript { value, slice, .. } if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) - ) => (slice, true), + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() + ) => + { + (slice, false) + } + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + (slice, true) + } _ if core_config.kernel_ann.is_none() => (annotation, true), - _ => continue // ignore fields annotated otherwise + _ => continue, // ignore fields annotated otherwise }; class_fields_def.push((*attr, dummy_field_type, mutable)); @@ -1220,8 +1271,7 @@ impl TopLevelComposer { } else { return Err(format!( "same class fields `{}` defined twice (at {})", - attr, - target.location + attr, target.location )); } } else { @@ -1233,10 +1283,12 @@ impl TopLevelComposer { } ast::StmtKind::Pass { .. } => {} ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior - _ => return Err(format!( - "unsupported statement in class definition body (at {})", - b.location - )), + _ => { + return Err(format!( + "unsupported statement in class definition body (at {})", + b.location + )) + } } } Ok(()) @@ -1394,23 +1446,38 @@ impl TopLevelComposer { primitives_ty, &make_self_type_annotation(type_vars, *object_id), )?; - if ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { + if ancestors + .iter() + .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) + { // create constructor for these classes let string = primitives_ty.str; let int64 = primitives_ty.int64; let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { name: "msg".into(), ty: string, - default_value: Some(SymbolValue::Str("".into()))}, - FuncArg { name: "param0".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, - FuncArg { name: "param1".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, - FuncArg { name: "param2".into(), ty: int64, - default_value: Some(SymbolValue::I64(0))}, + FuncArg { + name: "msg".into(), + ty: string, + default_value: Some(SymbolValue::Str("".into())), + }, + FuncArg { + name: "param0".into(), + ty: int64, + default_value: Some(SymbolValue::I64(0)), + }, + FuncArg { + name: "param1".into(), + ty: int64, + default_value: Some(SymbolValue::I64(0)), + }, + FuncArg { + name: "param2".into(), + ty: int64, + default_value: Some(SymbolValue::I64(0)), + }, ], ret: self_type, - vars: Default::default() + vars: Default::default(), })); let cons_fun = TopLevelDef::Function { name: format!("{}.{}", class_name, "__init__"), @@ -1421,14 +1488,13 @@ impl TopLevelComposer { instance_to_stmt: Default::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))), - loc: None + loc: None, }; constructors.push((i, signature, definition_extension.len())); definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); - unifier - .unify(constructor.unwrap(), signature) - .map_err(|e| e.at(Some(ast.as_ref().unwrap().location)) - .to_display(unifier).to_string())?; + unifier.unify(constructor.unwrap(), signature).map_err(|e| { + e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() + })?; return Ok(()); } let mut init_id: Option = None; @@ -1439,7 +1505,9 @@ impl TopLevelComposer { for (name, func_sig, id) in methods { if *name == init_str_id { init_id = Some(*id); - if let TypeEnum::TFunc(FunSignature { args, vars, ..}) = unifier.get_ty(*func_sig).as_ref() { + if let TypeEnum::TFunc(FunSignature { args, vars, .. }) = + unifier.get_ty(*func_sig).as_ref() + { constructor_args.extend_from_slice(args); type_vars.extend(vars); } else { @@ -1449,17 +1517,18 @@ impl TopLevelComposer { } (constructor_args, type_vars) }; - let contor_type = unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: contor_args, ret: self_type, vars: contor_type_vars } - )); - unifier - .unify(constructor.unwrap(), contor_type) - .map_err(|e| e.at(Some(ast.as_ref().unwrap().location)).to_display(&unifier).to_string())?; + let contor_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: contor_args, + ret: self_type, + vars: contor_type_vars, + })); + unifier.unify(constructor.unwrap(), contor_type).map_err(|e| { + e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() + })?; // class field instantiation check if let (Some(init_id), false) = (init_id, fields.is_empty()) { - let init_ast = - definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); + let init_ast = definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { if *name != init_str_id { unreachable!("must be init function here") @@ -1490,9 +1559,13 @@ impl TopLevelComposer { } for (i, signature, id) in constructors.into_iter() { - if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() { - methods.push((init_str_id, signature, - DefinitionId(self.definition_ast_list.len() + id))); + if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() + { + methods.push(( + init_str_id, + signature, + DefinitionId(self.definition_ast_list.len() + id), + )); } else { unreachable!() } @@ -1508,7 +1581,7 @@ impl TopLevelComposer { let method_class = &mut self.method_class; let mut analyze_2 = |id, def: &Arc>, ast: &Option| { if ast.is_none() { - return Ok(()) + return Ok(()); } let mut function_def = def.write(); if let TopLevelDef::Function { @@ -1522,7 +1595,9 @@ impl TopLevelComposer { .. } = &mut *function_def { - if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = unifier.get_ty(*signature).as_ref() { + if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = + unifier.get_ty(*signature).as_ref() + { // None if is not class method let uninst_self_type = { if let Some(class_id) = method_class.get(&DefinitionId(id)) { @@ -1553,7 +1628,8 @@ impl TopLevelComposer { .iter() .map(|(_, ty)| { unifier.get_instantiations(*ty).unwrap_or_else(|| { - if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) { + if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) + { let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; no_ranges.push(rigid); vec![rigid] @@ -1588,33 +1664,32 @@ impl TopLevelComposer { .collect_vec() }; let self_type = { - uninst_self_type - .clone() - .map(|(self_type, type_vars)| { - let subst_for_self = { - let class_ty_var_ids = type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { - *id - } else { - unreachable!("must be type var here"); - } - }) - .collect::>(); - subst - .iter() - .filter_map(|(ty_var_id, ty_var_target)| { - if class_ty_var_ids.contains(ty_var_id) { - Some((*ty_var_id, *ty_var_target)) - } else { - None - } - }) - .collect::>() - }; - unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) - }) + uninst_self_type.clone().map(|(self_type, type_vars)| { + let subst_for_self = { + let class_ty_var_ids = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) + { + *id + } else { + unreachable!("must be type var here"); + } + }) + .collect::>(); + subst + .iter() + .filter_map(|(ty_var_id, ty_var_target)| { + if class_ty_var_ids.contains(ty_var_id) { + Some((*ty_var_id, *ty_var_target)) + } else { + None + } + }) + .collect::>() + }; + unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) + }) }; let mut identifiers = { // NOTE: none and function args? @@ -1632,9 +1707,7 @@ impl TopLevelComposer { defined_identifiers: identifiers.clone(), function_data: &mut FunctionData { resolver: resolver.as_ref().unwrap().clone(), - return_type: if unifier - .unioned(inst_ret, primitives_ty.none) - { + return_type: if unifier.unioned(inst_ret, primitives_ty.none) { None } else { Some(inst_ret) @@ -1656,7 +1729,7 @@ impl TopLevelComposer { primitives: primitives_ty, virtual_checks: &mut Vec::new(), calls: &mut calls, - in_handler: false + in_handler: false, }; let fun_body = @@ -1696,7 +1769,10 @@ impl TopLevelComposer { if let TypeEnum::TObj { obj_id, .. } = &*base { *obj_id } else { - return Err(format!("Base type should be a class (at {})", loc)) + return Err(format!( + "Base type should be a class (at {})", + loc + )); } }; let subtype_id = { @@ -1706,7 +1782,10 @@ impl TopLevelComposer { } else { let base_repr = inferencer.unifier.stringify(*base); let subtype_repr = inferencer.unifier.stringify(*subtype); - return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) + return Err(format!( + "Expected a subtype of {}, but got {} (at {})", + base_repr, subtype_repr, loc + )); } }; let subtype_entry = defs[subtype_id.0].read(); @@ -1716,7 +1795,10 @@ impl TopLevelComposer { if m.is_none() { let base_repr = inferencer.unifier.stringify(*base); let subtype_repr = inferencer.unifier.stringify(*subtype); - return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) + return Err(format!( + "Expected a subtype of {}, but got {} (at {})", + base_repr, subtype_repr, loc + )); } } else { unreachable!(); @@ -1748,12 +1830,7 @@ impl TopLevelComposer { } instance_to_stmt.insert( - get_subst_key( - unifier, - self_type, - &subst, - Some(insted_vars), - ), + get_subst_key(unifier, self_type, &subst, Some(insted_vars)), FunInstance { body: Arc::new(fun_body), unifier_id: 0, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21341fc..6c131e6 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,32 +1,22 @@ use std::convert::TryInto; -use nac3parser::ast::{Constant, Location}; use crate::symbol_resolver::SymbolValue; +use nac3parser::ast::{Constant, Location}; use super::*; impl TopLevelDef { - pub fn to_string( - &self, - unifier: &mut Unifier, - ) -> String - { + pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Class { - name, ancestors, fields, methods, type_vars, .. - } => { + TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { let fields_str = fields .iter() - .map(|(n, ty, _)| { - (n.to_string(), unifier.stringify(*ty)) - }) + .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); let methods_str = methods .iter() - .map(|(n, ty, id)| { - (n.to_string(), unifier.stringify(*ty), *id) - }) + .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", @@ -57,38 +47,38 @@ impl TopLevelComposer { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(2), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(4), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(6), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(7), @@ -102,8 +92,10 @@ impl TopLevelComposer { ("__param0__".into(), (int64, true)), ("__param1__".into(), (int64, true)), ("__param2__".into(), (int64, true)), - ].into_iter().collect::>().into(), - params: HashMap::new().into(), + ] + .into_iter() + .collect::>(), + params: HashMap::new(), }); let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); @@ -117,7 +109,7 @@ impl TopLevelComposer { resolver: Option>, name: StrRef, constructor: Option, - loc: Option + loc: Option, ) -> TopLevelDef { TopLevelDef::Class { name, @@ -138,7 +130,7 @@ impl TopLevelComposer { simple_name: StrRef, ty: Type, resolver: Option>, - loc: Option + loc: Option, ) -> TopLevelDef { TopLevelDef::Function { name, @@ -248,8 +240,11 @@ impl TopLevelComposer { let this = this.as_ref(); let other = unifier.get_ty(other); let other = other.as_ref(); - if let (TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, ..}), - TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. })) = (this, other) { + if let ( + TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), + TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), + ) = (this, other) + { // check args let args_ok = this_args .iter() @@ -362,11 +357,19 @@ impl TopLevelComposer { Ok(result) } - pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + pub fn parse_parameter_default_value( + default: &ast::Expr, + resolver: &(dyn SymbolResolver + Send + Sync), + ) -> Result { parse_parameter_default_value(default, resolver) } - pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> { + pub fn check_default_param_type( + val: &SymbolValue, + ty: &TypeAnnotation, + primitive: &PrimitiveStore, + unifier: &mut Unifier, + ) -> Result<(), String> { let res = match val { SymbolValue::Bool(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { @@ -430,33 +433,26 @@ impl TopLevelComposer { } } -pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { +pub fn parse_parameter_default_value( + default: &ast::Expr, + resolver: &(dyn SymbolResolver + Send + Sync), +) -> Result { fn handle_constant(val: &Constant, loc: &Location) -> Result { match val { - Constant::Int(v) => { - match v { - Some(v) => { - if let Ok(v) = (*v).try_into() { - Ok(SymbolValue::I32(v)) - } else { - Err(format!( - "integer value out of range at {}", - loc - )) - } - }, - None => { - Err(format!( - "integer value out of range at {}", - loc - )) + Constant::Int(v) => match v { + Some(v) => { + if let Ok(v) = (*v).try_into() { + Ok(SymbolValue::I32(v)) + } else { + Err(format!("integer value out of range at {}", loc)) } } - } + None => Err(format!("integer value out of range at {}", loc)), + }, Constant::Float(v) => Ok(SymbolValue::Double(*v)), Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( - tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()? + tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), _ => unimplemented!("this constant is not supported at {}", loc), } diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index e35846f..8232632 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -34,7 +34,10 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { - fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + fn get_default_param_value( + &self, + _: &nac3parser::ast::Expr, + ) -> Option { unimplemented!() } @@ -169,10 +172,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s { let def = &*def.read(); if let TopLevelDef::Function { signature, name, .. } = def { - let ty_str = - composer - .unifier - .internal_stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string(), &mut None); + let ty_str = composer.unifier.internal_stringify( + *signature, + &mut |id| id.to_string(), + &mut |id| id.to_string(), + &mut None, + ); assert_eq!(ty_str, tys[i]); assert_eq!(name, names[i]); } @@ -779,9 +784,12 @@ impl<'a> Fold> for TypeToStringFolder<'a> { type Error = String; fn map_user(&mut self, user: Option) -> Result { Ok(if let Some(ty) = user { - self.unifier.internal_stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| { - format!("tvar{}", id.to_string()) - }, &mut None) + self.unifier.internal_stringify( + ty, + &mut |id| format!("class{}", id.to_string()), + &mut |id| format!("tvar{}", id.to_string()), + &mut None, + ) } else { "None".into() }) diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 0064ac4..497218e 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -23,17 +23,27 @@ impl TypeAnnotation { Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), CustomClass { id, params } => { let class_name = match unifier.top_level { - Some(ref top) => if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() { - (*name).into() - } else { - format!("def_{}", id.0) + Some(ref top) => { + if let TopLevelDef::Class { name, .. } = + &*top.definitions.read()[id.0].read() + { + (*name).into() + } else { + format!("def_{}", id.0) + } } - None => format!("def_{}", id.0) + None => format!("def_{}", id.0), }; - format!("{{class: {}, params: {:?}}}", class_name, params.iter().map(|p| p.stringify(unifier)).collect_vec()) + format!( + "{{class: {}, params: {:?}}}", + class_name, + params.iter().map(|p| p.stringify(unifier)).collect_vec() + ) } Virtual(ty) | List(ty) => ty.stringify(unifier), - Tuple(types) => format!("({:?})", types.iter().map(|p| p.stringify(unifier)).collect_vec()), + Tuple(types) => { + format!("({:?})", types.iter().map(|p| p.stringify(unifier)).collect_vec()) + } } } } @@ -47,7 +57,9 @@ pub fn parse_ast_to_type_annotation_kinds( // the key stores the type_var of this topleveldef::class, we only need this field here locked: HashMap>, ) -> Result { - let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap>| { + let name_handle = |id: &StrRef, + unifier: &mut Unifier, + locked: HashMap>| { if id == &"int32".into() { Ok(TypeAnnotation::Primitive(primitives.int32)) } else if id == &"int64".into() { @@ -93,11 +105,7 @@ pub fn parse_ast_to_type_annotation_kinds( unifier.unify(var, ty).unwrap(); Ok(TypeAnnotation::TypeVar(ty)) } else { - Err(format!( - "`{}` is not a valid type annotation (at {})", - id, - expr.location - )) + Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) } } else { Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) @@ -105,73 +113,73 @@ pub fn parse_ast_to_type_annotation_kinds( }; let class_name_handle = - |id: &StrRef, slice: &ast::Expr, unifier: &mut Unifier, mut locked: HashMap>| { - if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] - .contains(id) + |id: &StrRef, + slice: &ast::Expr, + unifier: &mut Unifier, + mut locked: HashMap>| { + if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()].contains(id) { return Err(format!("keywords cannot be class name (at {})", expr.location)); } - let obj_id = resolver - .get_identifier_def(*id)?; - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() + let obj_id = resolver.get_identifier_def(*id)?; + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() + } else { + unreachable!("must be class here") + } } else { - unreachable!("must be class here") + locked.get(&obj_id).unwrap().clone() } - } else { - locked.get(&obj_id).unwrap().clone() - } - }; - // we do not check whether the application of type variables are compatible here - let param_type_infos = { - let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - elts.iter().collect_vec() - } else { - vec![slice] }; - if type_vars.len() != params_ast.len() { - return Err(format!( - "expect {} type parameters but got {} (at {})", - type_vars.len(), - params_ast.len(), - params_ast[0].location, - )); - } - let result = params_ast - .iter() - .map(|x| { - parse_ast_to_type_annotation_kinds( - resolver, - top_level_defs, - unifier, - primitives, - x, - { - locked.insert(obj_id, type_vars.clone()); - locked.clone() - }, - ) - }) - .collect::, _>>()?; - // make sure the result do not contain any type vars - let no_type_var = result - .iter() - .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); - if no_type_var { - result - } else { - return Err(format!( - "application of type vars to generic class \ + // we do not check whether the application of type variables are compatible here + let param_type_infos = { + let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter().collect_vec() + } else { + vec![slice] + }; + if type_vars.len() != params_ast.len() { + return Err(format!( + "expect {} type parameters but got {} (at {})", + type_vars.len(), + params_ast.len(), + params_ast[0].location, + )); + } + let result = params_ast + .iter() + .map(|x| { + parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + x, + { + locked.insert(obj_id, type_vars.clone()); + locked.clone() + }, + ) + }) + .collect::, _>>()?; + // make sure the result do not contain any type vars + let no_type_var = + result.iter().all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); + if no_type_var { + result + } else { + return Err(format!( + "application of type vars to generic class \ is not currently supported (at {})", - params_ast[0].location - )); - } + params_ast[0].location + )); + } + }; + Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) }; - Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) - }; match &expr.node { ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), // virtual @@ -297,8 +305,11 @@ pub fn get_type_from_type_annotation_kinds( let ok: bool = { // create a temp type var and unify to check compatibility p == *tvar || { - let temp = - unifier.get_fresh_var_with_range(range.as_slice(), *name, *loc); + let temp = unifier.get_fresh_var_with_range( + range.as_slice(), + *name, + *loc, + ); unifier.unify(temp.0, p).is_ok() } }; @@ -338,7 +349,7 @@ pub fn get_type_from_type_annotation_kinds( Ok(unifier.add_ty(TypeEnum::TObj { obj_id: *obj_id, fields: tobj_fields, - params: subst.into(), + params: subst, })) } } else { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index cced74e..95dd843 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -65,21 +65,19 @@ pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { } pub(super) fn with_fields(unifier: &mut Unifier, ty: Type, f: F) - where F: FnOnce(&mut Unifier, &mut HashMap) +where + F: FnOnce(&mut Unifier, &mut HashMap), { - let (id, mut fields, params) = if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) { - (*obj_id, fields.clone(), params.clone()) - } else { - unreachable!() - }; + let (id, mut fields, params) = + if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) { + (*obj_id, fields.clone(), params.clone()) + } else { + unreachable!() + }; f(unifier, &mut fields); unsafe { let unification_table = unifier.get_unification_table(); - unification_table.set_value(ty, Rc::new(TypeEnum::TObj { - obj_id: id, - fields, - params, - })); + unification_table.set_value(ty, Rc::new(TypeEnum::TObj { obj_id: id, fields, params })); } } @@ -106,34 +104,30 @@ pub fn impl_binop( for op in ops { fields.insert(binop_name(op).into(), { ( - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: ret_ty, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + ret: ret_ty, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + })), false, ) }); fields.insert(binop_assign_name(op).into(), { ( - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: store.none, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + ret: store.none, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + })), false, ) }); @@ -141,20 +135,17 @@ pub fn impl_binop( }); } -pub fn impl_unaryop( - unifier: &mut Unifier, - ty: Type, - ret_ty: Type, - ops: &[ast::Unaryop], -) { +pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::Unaryop]) { with_fields(unifier, ty, |unifier, fields| { for op in ops { fields.insert( unaryop_name(op).into(), ( - unifier.add_ty(TypeEnum::TFunc( - FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] } - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + ret: ret_ty, + vars: HashMap::new(), + args: vec![], + })), false, ), ); @@ -174,17 +165,15 @@ pub fn impl_cmpop( fields.insert( comparison_name(op).unwrap().into(), ( - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: store.bool, - vars: HashMap::new(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - } - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + ret: store.bool, + vars: HashMap::new(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + })), false, ), ); diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index 69b0ede..4cac1bf 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -1,6 +1,6 @@ mod function_check; pub mod magic_methods; +pub mod type_error; pub mod type_inferencer; pub mod typedef; -pub mod type_error; mod unification_table; diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index 3533d6c..bf96396 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -1,9 +1,9 @@ -use std::fmt::Display; use std::collections::HashMap; +use std::fmt::Display; use crate::typecheck::typedef::TypeEnum; -use super::typedef::{Type, Unifier, RecordKey}; +use super::typedef::{RecordKey, Type, Unifier}; use nac3parser::ast::{Location, StrRef}; #[derive(Debug, Clone)] @@ -53,16 +53,13 @@ impl TypeError { } pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError { - DisplayTypeError { - err: self, - unifier - } + DisplayTypeError { err: self, unifier } } } pub struct DisplayTypeError<'a> { pub err: TypeError, - pub unifier: &'a Unifier + pub unifier: &'a Unifier, } fn loc_to_str(loc: Option) -> String { @@ -86,11 +83,7 @@ impl<'a> Display for DisplayTypeError<'a> { UnknownArgName(name) => { write!(f, "Unknown argument name: {}", name) } - IncorrectArgType { - name, - expected, - got, - } => { + IncorrectArgType { name, expected, got } => { let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes); write!( @@ -98,19 +91,26 @@ impl<'a> Display for DisplayTypeError<'a> { "Incorrect argument type for {}. Expected {}, but got {}", name, expected, got ) - }, + } FieldUnificationError { field, types, loc } => { let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let rhs = self.unifier.stringify_with_notes(types.1, &mut notes); write!( f, "Unable to unify field {}: Got types {}{} and {}{}", - field, lhs, loc_to_str(loc.0), rhs, loc_to_str(loc.1) + field, + lhs, + loc_to_str(loc.0), + rhs, + loc_to_str(loc.1) ) } IncompatibleRange(t, ts) => { let t = self.unifier.stringify_with_notes(*t, &mut notes); - let ts = ts.iter().map(|t| self.unifier.stringify_with_notes(*t, &mut notes)).collect::>(); + let ts = ts + .iter() + .map(|t| self.unifier.stringify_with_notes(*t, &mut notes)) + .collect::>(); write!(f, "Expected any one of these types: {}, but got {}", ts.join(", "), t) } IncompatibleTypes(t1, t2) => { @@ -119,15 +119,21 @@ impl<'a> Display for DisplayTypeError<'a> { match (&*type1, &*type2) { (TypeEnum::TCall(calls), _) => { let loc = self.unifier.calls[calls[0].0].loc; - let result = write!(f, "{} is not callable", self.unifier.stringify_with_notes(*t2, &mut notes)); + let result = write!( + f, + "{} is not callable", + self.unifier.stringify_with_notes(*t2, &mut notes) + ); if let Some(loc) = loc { result?; write!(f, " (in {})", loc)?; - return Ok(()) + return Ok(()); } result } - (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) if ty1.len() != ty2.len() => { + (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) + if ty1.len() != ty2.len() => + { let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); write!(f, "Tuple length mismatch: got {} and {}", t1, t2) @@ -152,7 +158,11 @@ impl<'a> Display for DisplayTypeError<'a> { write!(f, "`{}::{}` field does not exist", t, name) } TupleIndexOutOfBounds { index, len } => { - write!(f, "Tuple index out of bounds. Got {} but tuple has only {} elements", index, len) + write!( + f, + "Tuple index out of bounds. Got {} but tuple has only {} elements", + index, len + ) } RequiresTypeAnn => { write!(f, "Unable to infer virtual object type: Type annotation required") @@ -174,4 +184,3 @@ impl<'a> Display for DisplayTypeError<'a> { Ok(()) } } - diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 11f86d8..9762fe0 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; -use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier, RecordField}; +use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; use itertools::izip; @@ -125,7 +125,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment } => { - let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + let body = body + .into_iter() + .map(|stmt| self.fold_stmt(stmt)) + .collect::, _>>()?; let outer_in_handler = self.in_handler; let mut exception_handlers = Vec::with_capacity(handlers.len()); self.in_handler = true; @@ -133,23 +136,29 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { let top_level_defs = self.top_level.definitions.read(); let mut naive_folder = NaiveFolder(); for handler in handlers.into_iter() { - let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = handler.node; + let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = + handler.node; let type_ = if let Some(type_) = type_ { let typ = self.function_data.resolver.parse_type_annotation( top_level_defs.as_slice(), self.unifier, self.primitives, - &type_ + &type_, )?; - self.virtual_checks.push((typ, self.primitives.exception, handler.location)); + self.virtual_checks.push(( + typ, + self.primitives.exception, + handler.location, + )); if let Some(name) = name { if !self.defined_identifiers.contains(&name) { self.defined_identifiers.insert(name); } if let Some(old_typ) = self.variable_mapping.insert(name, typ) { let loc = handler.location; - self.unifier.unify(old_typ, typ).map_err(|e| e.at(Some(loc)) - .to_display(self.unifier).to_string())?; + self.unifier.unify(old_typ, typ).map_err(|e| { + e.at(Some(loc)).to_display(self.unifier).to_string() + })?; } } let mut type_ = naive_folder.fold_expr(*type_)?; @@ -158,22 +167,32 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } else { None }; - let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + let body = body + .into_iter() + .map(|stmt| self.fold_stmt(stmt)) + .collect::, _>>()?; exception_handlers.push(Located { location: handler.location, node: ast::ExcepthandlerKind::ExceptHandler { type_, name, body }, - custom: None + custom: None, }); } } self.in_handler = outer_in_handler; let handlers = exception_handlers; - let orelse = orelse.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; - let finalbody = finalbody .into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + let orelse = orelse.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, + _, + >>( + )?; + let finalbody = finalbody + .into_iter() + .map(|stmt| self.fold_stmt(stmt)) + .collect::, _>>()?; Located { location: node.location, node: ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment }, - custom: None + custom: None, } } ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => { @@ -186,14 +205,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); self.unify(list, iter.custom.unwrap(), &iter.location)?; } - let body = body - .into_iter() - .map(|b| self.fold_stmt(b)) - .collect::, _>>()?; - let orelse = orelse - .into_iter() - .map(|o| self.fold_stmt(o)) - .collect::, _>>()?; + let body = + body.into_iter().map(|b| self.fold_stmt(b)).collect::, _>>()?; + let orelse = + orelse.into_iter().map(|o| self.fold_stmt(o)).collect::, _>>()?; Located { location: node.location, node: ast::StmtKind::For { @@ -204,7 +219,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { config_comment, type_comment, }, - custom: None + custom: None, } } ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { @@ -252,7 +267,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { }) .collect(); let loc = node.location; - let targets = targets.map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?; + let targets = targets + .map_err(|e| e.at(Some(loc)).to_display(self.unifier).to_string())?; return Ok(Located { location: node.location, node: ast::StmtKind::Assign { @@ -283,8 +299,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { _ => fold::fold_stmt(self, node)?, }; match &stmt.node { - ast::StmtKind::For { .. } => {}, - ast::StmtKind::Try { .. } => {}, + ast::StmtKind::For { .. } => {} + ast::StmtKind::Try { .. } => {} ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; } @@ -302,9 +318,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { return report_error("raise ... from cause is not supported", cause.location); } if let Some(exc) = exc { - self.virtual_checks.push((exc.custom.unwrap(), self.primitives.exception, exc.location)); + self.virtual_checks.push(( + exc.custom.unwrap(), + self.primitives.exception, + exc.location, + )); } else if !self.in_handler { - return report_error("cannot reraise outside exception handlers", stmt.location); + return report_error( + "cannot reraise outside exception handlers", + stmt.location, + ); } } ast::StmtKind::With { items, .. } => { @@ -419,8 +442,9 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { _ => fold::fold_expr(self, node)?, }; let custom = match &expr.node { - ast::ExprKind::Constant { value, .. } => - Some(self.infer_constant(value, &expr.location)?), + ast::ExprKind::Constant { value, .. } => { + Some(self.infer_constant(value, &expr.location)?) + } ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { match self.function_data.resolver.get_symbol_type( @@ -481,7 +505,9 @@ impl<'a> Inferencer<'a> { } fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), String> { - self.unifier.unify(a, b).map_err(|e| e.at(Some(*location)).to_display(self.unifier).to_string()) + self.unifier + .unify(a, b) + .map_err(|e| e.at(Some(*location)).to_display(self.unifier).to_string()) } fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> { @@ -533,9 +559,9 @@ impl<'a> Inferencer<'a> { .map(|v| v.name) .rev() .collect(); - self.unifier - .unify_call(&call, ty, sign, &required) - .map_err(|e| e.at(Some(location)).to_display(self.unifier).to_string())?; + self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| { + e.at(Some(location)).to_display(self.unifier).to_string() + })?; return Ok(sign.ret); } } @@ -585,8 +611,11 @@ impl<'a> Inferencer<'a> { defined_identifiers.insert(*name); } } - let fn_args: Vec<_> = - args.args.iter().map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)).collect(); + let fn_args: Vec<_> = args + .args + .iter() + .map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)) + .collect(); let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().cloned()); let ret = self.unifier.get_dummy_var().0; @@ -649,7 +678,7 @@ impl<'a> Inferencer<'a> { calls: self.calls, defined_identifiers, // listcomp expr should not be considered as inside an exception handler... - in_handler: false + in_handler: false, }; let generator = generators.pop().unwrap(); if generator.is_async { @@ -784,7 +813,7 @@ impl<'a> Inferencer<'a> { .collect(), fun: RefCell::new(None), ret: sign.ret, - loc: Some(location) + loc: Some(location), }; let required: Vec<_> = sign .args @@ -813,7 +842,7 @@ impl<'a> Inferencer<'a> { .collect(), fun: RefCell::new(None), ret, - loc: Some(location) + loc: Some(location), }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); @@ -853,8 +882,8 @@ impl<'a> Inferencer<'a> { } else { report_error("Integer out of bound", *loc) } - }, - None => report_error("Integer out of bound", *loc) + } + None => report_error("Integer out of bound", *loc), } } ast::Constant::Float(_) => Ok(self.primitives.float), @@ -900,8 +929,11 @@ impl<'a> Inferencer<'a> { } } else { let attr_ty = self.unifier.get_dummy_var().0; - let fields = once((attr.into(), RecordField::new( - attr_ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); + let fields = once(( + attr.into(), + RecordField::new(attr_ty, ctx == &ExprContext::Store, Some(value.location)), + )) + .collect(); let record = self.unifier.add_record(fields); self.constrain(value.custom.unwrap(), record, &value.location)?; Ok(attr_ty) @@ -986,8 +1018,11 @@ impl<'a> Inferencer<'a> { None => None, }; let ind = ind.ok_or_else(|| "Index must be int32".to_string())?; - let map = once((ind.into(), RecordField::new( - ty, ctx == &ExprContext::Store, Some(value.location)))).collect(); + let map = once(( + ind.into(), + RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), + )) + .collect(); let seq = self.unifier.add_record(map); self.constrain(value.custom.unwrap(), seq, &value.location)?; Ok(ty) diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index c922ed1..dae4ad9 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,4 +1,4 @@ -use super::super::{typedef::*, magic_methods::with_fields}; +use super::super::{magic_methods::with_fields, typedef::*}; use super::*; use crate::{ codegen::CodeGenContext, @@ -18,7 +18,10 @@ struct Resolver { } impl SymbolResolver for Resolver { - fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + fn get_default_param_value( + &self, + _: &nac3parser::ast::Expr, + ) -> Option { unimplemented!() } @@ -66,54 +69,51 @@ impl TestEnvironment { let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); with_fields(&mut unifier, int32, |unifier, fields| { - let add_ty = unifier.add_ty(TypeEnum::TFunc( - FunSignature { - args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], - ret: int32, - vars: HashMap::new(), - } - .into(), - )); + let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], + ret: int32, + vars: HashMap::new(), + })); fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(2), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(4), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(6), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(7), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; set_primitives_magic_methods(&primitives, &mut unifier); @@ -167,58 +167,56 @@ impl TestEnvironment { let mut top_level_defs: Vec>> = Vec::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); with_fields(&mut unifier, int32, |unifier, fields| { - let add_ty = unifier.add_ty(TypeEnum::TFunc( - FunSignature { - args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], - ret: int32, - vars: HashMap::new(), - } - .into(), - )); + let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], + ret: int32, + vars: HashMap::new(), + })); fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(2), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(4), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(6), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(7), - fields: HashMap::new().into(), - params: HashMap::new().into(), + fields: HashMap::new(), + params: HashMap::new(), }); identifier_mapping.insert("None".into(), none); - for (i, name) in - ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"].iter().enumerate() + for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] + .iter() + .enumerate() { top_level_defs.push( RwLock::new(TopLevelDef::Class { @@ -230,7 +228,7 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, - loc: None + loc: None, }) .into(), ); @@ -243,8 +241,8 @@ impl TestEnvironment { let foo_ty = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(defs + 1), - fields: [("a".into(), (v0, true))].iter().cloned().collect::>().into(), - params: [(id, v0)].iter().cloned().collect::>().into(), + fields: [("a".into(), (v0, true))].iter().cloned().collect::>(), + params: [(id, v0)].iter().cloned().collect::>(), }); top_level_defs.push( RwLock::new(TopLevelDef::Class { @@ -263,26 +261,24 @@ impl TestEnvironment { identifier_mapping.insert( "Foo".into(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - args: vec![], - ret: foo_ty, - vars: [(id, v0)].iter().cloned().collect(), - } - .into(), - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: foo_ty, + vars: [(id, v0)].iter().cloned().collect(), + })), ); - let fun = unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(), - )); + let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: int32, + vars: Default::default(), + })); let bar = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(defs + 2), fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))] .iter() .cloned() - .collect::>() - .into(), + .collect::>(), params: Default::default(), }); top_level_defs.push( @@ -295,15 +291,17 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, - loc: None + loc: None, }) .into(), ); identifier_mapping.insert( "Bar".into(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(), - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: bar, + vars: Default::default(), + })), ); let bar2 = unifier.add_ty(TypeEnum::TObj { @@ -311,8 +309,7 @@ impl TestEnvironment { fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))] .iter() .cloned() - .collect::>() - .into(), + .collect::>(), params: Default::default(), }); top_level_defs.push( @@ -325,15 +322,17 @@ impl TestEnvironment { ancestors: Default::default(), resolver: None, constructor: None, - loc: None + loc: None, }) .into(), ); identifier_mapping.insert( "Bar2".into(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(), - )), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: bar2, + vars: Default::default(), + })), ); let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); @@ -400,7 +399,7 @@ impl TestEnvironment { virtual_checks: &mut self.virtual_checks, calls: &mut self.calls, defined_identifiers: Default::default(), - in_handler: false + in_handler: false, } } } @@ -493,7 +492,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st *v, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); println!("{}: {}", k, name); } @@ -503,7 +502,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st *ty, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } @@ -513,13 +512,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st *a, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); let b = inferencer.unifier.internal_stringify( *b, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); assert_eq!(&a, x); @@ -639,7 +638,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { *v, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); println!("{}: {}", k, name); } @@ -649,7 +648,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { *ty, &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), - &mut None + &mut None, ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 516191d..579d80a 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -6,10 +6,10 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use nac3parser::ast::{StrRef, Location}; +use nac3parser::ast::{Location, StrRef}; -use super::unification_table::{UnificationKey, UnificationTable}; use super::type_error::{TypeError, TypeErrorKind}; +use super::unification_table::{UnificationKey, UnificationTable}; use crate::symbol_resolver::SymbolValue; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; @@ -51,14 +51,14 @@ pub struct FunSignature { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum RecordKey { Str(StrRef), - Int(i32) + Int(i32), } impl From<&RecordKey> for StrRef { fn from(r: &RecordKey) -> Self { match r { RecordKey::Str(s) => *s, - RecordKey::Int(i) => StrRef::from(i.to_string()) + RecordKey::Int(i) => StrRef::from(i.to_string()), } } } @@ -85,7 +85,7 @@ impl Display for RecordKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { RecordKey::Str(s) => write!(f, "{}", s), - RecordKey::Int(i) => write!(f, "{}", i) + RecordKey::Int(i) => write!(f, "{}", i), } } } @@ -94,7 +94,7 @@ impl Display for RecordKey { pub struct RecordField { ty: Type, mutable: bool, - loc: Option + loc: Option, } impl RecordField { @@ -108,7 +108,7 @@ pub enum TypeEnum { TRigidVar { id: u32, name: Option, - loc: Option + loc: Option, }, TVar { id: u32, @@ -117,7 +117,7 @@ pub enum TypeEnum { // empty indicates no restriction range: Vec, name: Option, - loc: Option + loc: Option, }, TTuple { ty: Vec, @@ -264,7 +264,11 @@ impl Unifier { self.unification_table.probe_value_immutable(a).clone() } - pub fn get_fresh_rigid_var(&mut self, name: Option, loc: Option) -> (Type, u32) { + pub fn get_fresh_rigid_var( + &mut self, + name: Option, + loc: Option, + ) -> (Type, u32) { let id = self.var_id + 1; self.var_id += 1; (self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id) @@ -279,11 +283,16 @@ impl Unifier { } /// Get a fresh type variable. - pub fn get_fresh_var_with_range(&mut self, range: &[Type], name: Option, loc: Option) -> (Type, u32) { + pub fn get_fresh_var_with_range( + &mut self, + range: &[Type], + name: Option, + loc: Option, + ) -> (Type, u32) { let id = self.var_id + 1; self.var_id += 1; let range = range.to_vec(); - (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc}), id) + (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id) } /// Unification would not unify rigid variables with other types, but we want to do this for @@ -344,8 +353,7 @@ impl Unifier { .map(|params| { self.subst( ty, - &zip(keys.iter().cloned(), params.iter().cloned()) - .collect(), + &zip(keys.iter().cloned(), params.iter().cloned()).collect(), ) .unwrap_or(ty) }) @@ -395,23 +403,19 @@ impl Unifier { // we check to make sure that all required arguments (those without default // arguments) are provided, and do not provide the same argument twice. let mut required = required.to_vec(); - let mut all_names: Vec<_> = - signature.args.iter().map(|v| (v.name, v.ty)).rev().collect(); + let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect(); for (i, t) in posargs.iter().enumerate() { if signature.args.len() <= i { - return Err(TypeError::new(TypeErrorKind::TooManyArguments{ - expected: signature.args.len(), - got: i, - }, *loc)); + return Err(TypeError::new( + TypeErrorKind::TooManyArguments { expected: signature.args.len(), got: i }, + *loc, + )); } required.pop(); let (name, expected) = all_names.pop().unwrap(); - self.unify_impl(expected, *t, false) - .map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { - name, - expected, - got: *t, - }, *loc))?; + self.unify_impl(expected, *t, false).map_err(|_| { + TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) + })?; } for (k, t) in kwargs.iter() { if let Some(i) = required.iter().position(|v| v == k) { @@ -422,23 +426,22 @@ impl Unifier { .position(|v| &v.0 == k) .ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; let (name, expected) = all_names.remove(i); - self.unify_impl(expected, *t, false) - .map_err(|_| TypeError::new(TypeErrorKind::IncorrectArgType { - name, - expected, - got: *t, - }, *loc))?; + self.unify_impl(expected, *t, false).map_err(|_| { + TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) + })?; } if !required.is_empty() { - return Err(TypeError::new(TypeErrorKind::MissingArgs(required.iter().join(", ")), *loc)); + return Err(TypeError::new( + TypeErrorKind::MissingArgs(required.iter().join(", ")), + *loc, + )); } - self.unify_impl(*ret, signature.ret, false) - .map_err(|mut err| { - if err.loc.is_none() { - err.loc = *loc; - } - err - })?; + self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { + if err.loc.is_none() { + err.loc = *loc; + } + err + })?; *fun.borrow_mut() = Some(instantiated); Ok(()) } @@ -471,24 +474,38 @@ impl Unifier { ) }; match (&*ty_a, &*ty_b) { - (TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields2, name: name2, loc: loc2, .. }) => { + ( + TVar { fields: fields1, id, name: name1, loc: loc1, .. }, + TVar { fields: fields2, name: name2, loc: loc2, .. }, + ) => { let new_fields = match (fields1, fields2) { (None, None) => None, (None, Some(fields)) => Some(fields.clone()), (_, None) => { return self.unify_impl(b, a, true); - }, + } (Some(fields1), Some(fields2)) => { let mut new_fields: Mapping<_, _> = fields2.clone(); for (key, val1) in fields1.iter() { if let Some(val2) = fields2.get(key) { - self.unify_impl(val1.ty, val2.ty, false) - .map_err(|_| TypeError::new(TypeErrorKind::FieldUnificationError { - field: *key, - types: (val1.ty, val2.ty), - loc: (*loc1, *loc2), - }, None))?; - new_fields.insert(*key, RecordField::new(val1.ty, val1.mutable || val2.mutable, val1.loc.or(val2.loc))); + self.unify_impl(val1.ty, val2.ty, false).map_err(|_| { + TypeError::new( + TypeErrorKind::FieldUnificationError { + field: *key, + types: (val1.ty, val2.ty), + loc: (*loc1, *loc2), + }, + None, + ) + })?; + new_fields.insert( + *key, + RecordField::new( + val1.ty, + val1.mutable || val2.mutable, + val1.loc.or(val2.loc), + ), + ); } else { new_fields.insert(*key, *val1); } @@ -496,21 +513,26 @@ impl Unifier { Some(new_fields) } }; - let intersection = self.get_intersection(a, b).map_err(|_| - TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?.unwrap(); + let intersection = self + .get_intersection(a, b) + .map_err(|_| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))? + .unwrap(); let range = if let TypeEnum::TVar { range, .. } = &*self.get_ty(intersection) { range.clone() } else { unreachable!() }; self.unification_table.unify(a, b); - self.unification_table.set_value(a, Rc::new(TypeEnum::TVar { - id: *id, - fields: new_fields, - range, - name: name1.or(*name2), - loc: loc1.or(*loc2) - })); + self.unification_table.set_value( + a, + Rc::new(TypeEnum::TVar { + id: *id, + fields: new_fields, + range, + name: name1.or(*name2), + loc: loc1.or(*loc2), + }), + ); } (TVar { fields: None, range, .. }, _) => { // We check for the range of the type variable to see if unification is allowed. @@ -520,8 +542,12 @@ impl Unifier { // The return value x of check_var_compatibility would be a new type that is // guaranteed to be compatible with a under all possible instantiations. So we // unify x with b to recursively apply the constrains, and then set a to x. - let x = self.check_var_compatibility(b, range).map_err(|_| - TypeError::new(TypeErrorKind::IncompatibleRange(b, range.clone()), None))?.unwrap_or(b); + let x = self + .check_var_compatibility(b, range) + .map_err(|_| { + TypeError::new(TypeErrorKind::IncompatibleRange(b, range.clone()), None) + })? + .unwrap_or(b); self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } @@ -532,17 +558,23 @@ impl Unifier { RecordKey::Int(i) => { if v.mutable { return Err(TypeError::new( - TypeErrorKind::MutationError(*k, b), v.loc)); + TypeErrorKind::MutationError(*k, b), + v.loc, + )); } let ind = if i < 0 { len + i } else { i }; if ind >= len || ind < 0 { return Err(TypeError::new( - TypeErrorKind::TupleIndexOutOfBounds{ index: i, len}, v.loc)); + TypeErrorKind::TupleIndexOutOfBounds { index: i, len }, + v.loc, + )); } - self.unify_impl(v.ty, ty[ind as usize], false).map_err(|e| e.at(v.loc))?; + self.unify_impl(v.ty, ty[ind as usize], false) + .map_err(|e| e.at(v.loc))?; + } + RecordKey::Str(_) => { + return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) } - RecordKey::Str(_) => return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), v.loc)), } } let x = self.check_var_compatibility(b, range)?.unwrap_or(b); @@ -552,9 +584,12 @@ impl Unifier { (TVar { fields: Some(fields), range, .. }, TList { ty }) => { for (k, v) in fields.iter() { match *k { - RecordKey::Int(_) => self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?, - RecordKey::Str(_) => return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), v.loc)), + RecordKey::Int(_) => { + self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))? + } + RecordKey::Str(_) => { + return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) + } } } let x = self.check_var_compatibility(b, range)?.unwrap_or(b); @@ -578,23 +613,26 @@ impl Unifier { for (k, field) in map.iter() { match *k { RecordKey::Str(s) => { - let (ty, mutable) = fields - .get(&s) - .copied() - .ok_or_else(|| TypeError::new( - TypeErrorKind::NoSuchField(*k, b), field.loc))?; + let (ty, mutable) = fields.get(&s).copied().ok_or_else(|| { + TypeError::new(TypeErrorKind::NoSuchField(*k, b), field.loc) + })?; // typevar represents the usage of the variable // it is OK to have immutable usage for mutable fields // but cannot have mutable usage for immutable fields - if field.mutable && !mutable{ + if field.mutable && !mutable { return Err(TypeError::new( - TypeErrorKind::MutationError(*k, b), field.loc)); + TypeErrorKind::MutationError(*k, b), + field.loc, + )); } - self.unify_impl(field.ty, ty, false) - .map_err(|v| v.at(field.loc))?; + self.unify_impl(field.ty, ty, false).map_err(|v| v.at(field.loc))?; + } + RecordKey::Int(_) => { + return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), + field.loc, + )) } - RecordKey::Int(_) => return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), field.loc)) } } let x = self.check_var_compatibility(b, range)?.unwrap_or(b); @@ -607,29 +645,35 @@ impl Unifier { for (k, field) in map.iter() { match *k { RecordKey::Str(s) => { - let (ty, _) = fields - .get(&s) - .copied() - .ok_or_else(|| TypeError::new( - TypeErrorKind::NoSuchField(*k, b), field.loc))?; + let (ty, _) = fields.get(&s).copied().ok_or_else(|| { + TypeError::new(TypeErrorKind::NoSuchField(*k, b), field.loc) + })?; if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), field.loc)) + TypeErrorKind::NoSuchField(*k, b), + field.loc, + )); } if field.mutable { return Err(TypeError::new( - TypeErrorKind::MutationError(*k, b), field.loc)); + TypeErrorKind::MutationError(*k, b), + field.loc, + )); } self.unify_impl(field.ty, ty, false) .map_err(|v| v.at(field.loc))?; } - RecordKey::Int(_) => return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), field.loc)) + RecordKey::Int(_) => { + return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), + field.loc, + )) + } } } } else { // require annotation... - return Err(TypeError::new(TypeErrorKind::RequiresTypeAnn, None)) + return Err(TypeError::new(TypeErrorKind::RequiresTypeAnn, None)); } let x = self.check_var_compatibility(b, range)?.unwrap_or(b); self.unify_impl(x, b, false)?; @@ -708,7 +752,11 @@ impl Unifier { self.stringify_with_notes(ty, &mut None) } - pub fn stringify_with_notes(&self, ty: Type, notes: &mut Option>) -> String { + pub fn stringify_with_notes( + &self, + ty: Type, + notes: &mut Option>, + ) -> String { let top_level = self.top_level.clone(); self.internal_stringify( ty, @@ -727,51 +775,82 @@ impl Unifier { ) }, &mut |id| format!("var{}", id), - notes + notes, ) } /// Get string representation of the type - pub fn internal_stringify(&self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G, notes: &mut Option>) -> String + pub fn internal_stringify( + &self, + ty: Type, + obj_to_name: &mut F, + var_to_name: &mut G, + notes: &mut Option>, + ) -> String where F: FnMut(usize) -> String, G: FnMut(u32) -> String, { let ty = self.unification_table.probe_value_immutable(ty).clone(); match ty.as_ref() { - TypeEnum::TRigidVar { id, name, .. } => name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), + TypeEnum::TRigidVar { id, name, .. } => { + name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)) + } TypeEnum::TVar { id, name, fields, range, .. } => { let n = if let Some(fields) = fields { - let mut fields = fields.iter().map(|(k, f)| format!("{}={}", k, self.internal_stringify(f.ty, obj_to_name, var_to_name, notes))); + let mut fields = fields.iter().map(|(k, f)| { + format!( + "{}={}", + k, + self.internal_stringify(f.ty, obj_to_name, var_to_name, notes) + ) + }); let fields = fields.join(", "); - format!("{}[{}]", name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), fields) + format!( + "{}[{}]", + name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)), + fields + ) } else { name.map(|v| v.to_string()).unwrap_or_else(|| var_to_name(*id)) }; - if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id) { + if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id) + { // just in case if there is any cyclic dependency notes.as_mut().unwrap().insert(*id, "".into()); - let body = format!("{} ∈ {{{}}}", n, range.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)).collect::>().join(", ")); + let body = format!( + "{} ∈ {{{}}}", + n, + range + .iter() + .map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)) + .collect::>() + .join(", ") + ); notes.as_mut().unwrap().insert(*id, body); }; n } TypeEnum::TTuple { ty } => { - let mut fields = ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); + let mut fields = + ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); format!("tuple[{}]", fields.join(", ")) } TypeEnum::TList { ty } => { format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } TypeEnum::TVirtual { ty } => { - format!("virtual[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) + format!( + "virtual[{}]", + self.internal_stringify(*ty, obj_to_name, var_to_name, notes) + ) } TypeEnum::TObj { obj_id, params, .. } => { let name = obj_to_name(obj_id.0); if !params.is_empty() { - let params = params.iter().map(|(_, v)| { - self.internal_stringify(*v, obj_to_name, var_to_name, notes) - }); + let params = params + .iter() + .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); // sort to preserve order let mut params = params.sorted(); format!("{}[{}]", name, params.join(", ")) @@ -786,9 +865,18 @@ impl Unifier { .iter() .map(|arg| { if let Some(dv) = &arg.default_value { - format!("{}:{}={}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), dv) + format!( + "{}:{}={}", + arg.name, + self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), + dv + ) } else { - format!("{}:{}", arg.name, self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)) + format!( + "{}:{}", + arg.name, + self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes) + ) } }) .join(", "); @@ -834,7 +922,9 @@ impl Unifier { } else { let mapping = vars .into_iter() - .map(|(k, range, name, loc)| (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0)) + .map(|(k, range, name, loc)| { + (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0) + }) .collect(); self.subst(ty, &mapping).unwrap_or(ty) } @@ -907,9 +997,8 @@ impl Unifier { let obj_id = *obj_id; let params = self.subst_map(params, mapping, cache).unwrap_or_else(|| params.clone()); - let fields = self - .subst_map2(fields, mapping, cache) - .unwrap_or_else(|| fields.clone()); + let fields = + self.subst_map2(fields, mapping, cache).unwrap_or_else(|| fields.clone()); let new_ty = self.add_ty(TypeEnum::TObj { obj_id, params, fields }); if let Some(var) = cache.get(&a).unwrap() { self.unify_impl(new_ty, *var, false).unwrap(); @@ -934,7 +1023,7 @@ impl Unifier { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.into_owned(); - Some( self.add_ty(TypeEnum::TFunc( FunSignature { args, ret, vars: params })),) + Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params }))) } else { None } @@ -992,7 +1081,10 @@ impl Unifier { let x = self.get_ty(a); let y = self.get_ty(b); match (x.as_ref(), y.as_ref()) { - (TVar { range: range1, name, loc, .. }, TVar { fields, range: range2, name: name2, loc: loc2, .. }) => { + ( + TVar { range: range1, name, loc, .. }, + TVar { fields, range: range2, name: name2, loc: loc2, .. }, + ) => { // new range is the intersection of them // empty range indicates no constraint if range1.is_empty() { @@ -1000,14 +1092,25 @@ impl Unifier { } else if range2.is_empty() { Ok(Some(a)) } else { - let range = range2.iter().cartesian_product(range1.iter()) - .filter_map(|(v1, v2)| self.get_intersection(*v1, *v2).map(|v| v.unwrap_or(*v1)).ok()).collect_vec(); + let range = range2 + .iter() + .cartesian_product(range1.iter()) + .filter_map(|(v1, v2)| { + self.get_intersection(*v1, *v2).map(|v| v.unwrap_or(*v1)).ok() + }) + .collect_vec(); if range.is_empty() { Err(()) } else { let id = self.var_id + 1; self.var_id += 1; - let ty = TVar { id, fields: fields.clone(), range, name: name2.or(*name), loc: loc2.or(*loc) }; + let ty = TVar { + id, + fields: fields.clone(), + range, + name: name2.or(*name), + loc: loc2.or(*loc), + }; Ok(Some(self.unification_table.new_key(ty.into()))) } } @@ -1026,13 +1129,15 @@ impl Unifier { Err(()) } } - (TVar { range, .. }, _) => { - self.check_var_compatibility(b, range).or(Err(())) - } + (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())), (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => { - let ty: Vec<_> = zip(ty1.iter(), ty2.iter()).map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?; + let ty: Vec<_> = zip(ty1.iter(), ty2.iter()) + .map(|(a, b)| self.get_intersection(*a, *b)) + .try_collect()?; if ty.iter().any(Option::is_some) { - Ok(Some(self.add_ty(TTuple { ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect()}))) + Ok(Some(self.add_ty(TTuple { + ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(), + }))) } else { Ok(None) } @@ -1043,9 +1148,7 @@ impl Unifier { (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } - (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) if id1 == id2 => { - Ok(None) - } + (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) if id1 == id2 => Ok(None), // don't deal with function shape for now _ => Err(()), } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 56e0fd3..eddf445 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -1,5 +1,5 @@ -use super::*; use super::super::magic_methods::with_fields; +use super::*; use indoc::indoc; use itertools::Itertools; use std::collections::HashMap; @@ -115,10 +115,7 @@ impl TestEnvironment { "Foo".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: [("a".into(), (v0, true))] - .iter() - .cloned() - .collect::>(), + fields: [("a".into(), (v0, true))].iter().cloned().collect::>(), params: [(id, v0)].iter().cloned().collect::>(), }), ); @@ -365,9 +362,11 @@ fn test_recursive_subst() { fn test_virtual() { let mut env = TestEnvironment::new(); let int = env.parse("int", &HashMap::new()); - let fun = env.unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: vec![], ret: int, vars: HashMap::new() }, - )); + let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: int, + vars: HashMap::new(), + })); let bar = env.unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] @@ -381,15 +380,21 @@ fn test_virtual() { let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); - let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect()); + let c = env + .unifier + .add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect()); env.unifier.unify(a, b).unwrap(); env.unifier.unify(b, c).unwrap(); assert!(env.unifier.eq(v1, fun)); - let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); + let d = env + .unifier + .add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field does not exist".to_string())); - let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); + let d = env + .unifier + .add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect()); assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field does not exist".to_string())); } @@ -451,10 +456,7 @@ fn test_typevar_range() { let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; env.unifier.unify(a, b).unwrap(); - assert_eq!( - env.unify(a, int), - Err("Expected any one of these types: 1, but got 0".into()) - ); + assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into())); let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; @@ -556,9 +558,12 @@ fn test_instantiation() { let types = types .iter() .map(|ty| { - env.unifier.internal_stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| { - format!("v{}", i) - }, &mut None) + env.unifier.internal_stringify( + *ty, + &mut |i| obj_map.get(&i).unwrap().to_string(), + &mut |i| format!("v{}", i), + &mut None, + ) }) .sorted() .collect_vec(); diff --git a/nac3standalone/demo/demo.rs b/nac3standalone/demo/demo.rs index 37ce8c6..01dc458 100644 --- a/nac3standalone/demo/demo.rs +++ b/nac3standalone/demo/demo.rs @@ -1,4 +1,5 @@ -mod cslice { // copied from https://github.com/dherman/cslice +mod cslice { + // copied from https://github.com/dherman/cslice use std::marker::PhantomData; use std::slice; @@ -7,14 +8,12 @@ mod cslice { // copied from https://github.com/dherman/cslice pub struct CSlice<'a, T> { base: *const T, len: usize, - marker: PhantomData<&'a ()> + marker: PhantomData<&'a ()>, } impl<'a, T> AsRef<[T]> for CSlice<'a, T> { fn as_ref(&self) -> &[T] { - unsafe { - slice::from_raw_parts(self.base, self.len) - } + unsafe { slice::from_raw_parts(self.base, self.len) } } } } @@ -43,7 +42,7 @@ pub extern "C" fn output_asciiart(x: i32) { pub extern "C" fn output_int32_list(x: &cslice::CSlice) { print!("["); let mut it = x.as_ref().iter().peekable(); - while let Some(e) = it.next() { + while let Some(e) = it.next() { if it.peek().is_none() { print!("{}", e); } else { @@ -58,7 +57,6 @@ pub extern "C" fn __artiq_personality(_state: u32, _exception_object: u32, _cont unimplemented!(); } - extern "C" { fn run() -> i32; } diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 3d1fbaa..517c848 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -38,10 +38,8 @@ pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { fn get_default_param_value(&self, expr: &ast::Expr) -> Option { match &expr.node { - ast::ExprKind::Name { id, .. } => { - self.0.module_globals.lock().get(id).cloned() - } - _ => unimplemented!("other type of expr not supported at {}", expr.location) + ast::ExprKind::Name { id, .. } => self.0.module_globals.lock().get(id).cloned(), + _ => unimplemented!("other type of expr not supported at {}", expr.location), } } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 3d9dada..bf99caf 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,24 +1,30 @@ use inkwell::{ + memory_buffer::MemoryBuffer, passes::{PassManager, PassManagerBuilder}, targets::*, - OptimizationLevel, memory_buffer::MemoryBuffer, + OptimizationLevel, }; +use parking_lot::{Mutex, RwLock}; use std::{borrow::Borrow, collections::HashMap, env, fs, path::Path, sync::Arc}; -use parking_lot::{RwLock, Mutex}; -use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser}; use nac3core::{ codegen::{ - concrete_type::ConcreteTypeStore, CodeGenTask, DefaultCodeGenerator, WithCall, - WorkerRegistry, irrt::load_irrt, + concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenTask, DefaultCodeGenerator, + WithCall, WorkerRegistry, }, symbol_resolver::SymbolResolver, toplevel::{ - composer::TopLevelComposer, - TopLevelDef, helper::parse_parameter_default_value, - type_annotation::*, + composer::TopLevelComposer, helper::parse_parameter_default_value, type_annotation::*, + TopLevelDef, }, - typecheck::{type_inferencer::PrimitiveStore, typedef::{Type, Unifier, FunSignature}} + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{FunSignature, Type, Unifier}, + }, +}; +use nac3parser::{ + ast::{Expr, ExprKind, StmtKind}, + parser, }; mod basic_symbol_resolver; @@ -26,10 +32,7 @@ use basic_symbol_resolver::*; fn main() { let file_name = env::args().nth(1).unwrap(); - let threads: u32 = env::args() - .nth(2) - .map(|s| str::parse(&s).unwrap()) - .unwrap_or(1); + let threads: u32 = env::args().nth(2).map(|s| str::parse(&s).unwrap()).unwrap_or(1); Target::initialize_all(&InitializationConfig::default()); @@ -42,10 +45,8 @@ fn main() { }; let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; - let (mut composer, builtins_def, builtins_ty) = TopLevelComposer::new( - vec![], - Default::default() - ); + let (mut composer, builtins_def, builtins_ty) = + TopLevelComposer::new(vec![], Default::default()); let internal_resolver: Arc = ResolverInternal { id_to_type: builtins_ty.into(), @@ -83,15 +84,23 @@ fn main() { x, Default::default(), )?; - get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty) + get_type_from_type_annotation_kinds( + def_list, unifier, primitives, &ty, + ) }) .collect::, _>>()?; Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) } else { - Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) + Err(format!( + "expression {:?} cannot be handled as a TypeVar in global scope", + var + )) } } else { - Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) + Err(format!( + "expression {:?} cannot be handled as a TypeVar in global scope", + var + )) } } @@ -116,7 +125,9 @@ fn main() { ) { internal_resolver.add_id_type(*id, var); Ok(()) - } else if let Ok(val) = parse_parameter_default_value(value.borrow(), resolver) { + } else if let Ok(val) = + parse_parameter_default_value(value.borrow(), resolver) + { internal_resolver.add_module_global(*id, val); Ok(()) } else { @@ -126,8 +137,7 @@ fn main() { )) } } - ExprKind::List { elts, .. } - | ExprKind::Tuple { elts, .. } => { + ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { handle_assignment_pattern( elts, value, @@ -135,16 +145,18 @@ fn main() { internal_resolver, def_list, unifier, - primitives + primitives, )?; Ok(()) } - _ => Err(format!("assignment to {:?} is not supported at {}", targets[0], targets[0].location)) + _ => Err(format!( + "assignment to {:?} is not supported at {}", + targets[0], targets[0].location + )), } } else { match &value.node { - ExprKind::List { elts, .. } - | ExprKind::Tuple { elts, .. } => { + ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { if elts.len() != targets.len() { Err(format!( "number of elements to unpack does not match (expect {}, found {}) at {}", @@ -161,13 +173,16 @@ fn main() { internal_resolver, def_list, unifier, - primitives + primitives, )?; } Ok(()) } - }, - _ => Err(format!("unpack of this expression is not supported at {}", value.location)) + } + _ => Err(format!( + "unpack of this expression is not supported at {}", + value.location + )), } } } @@ -190,9 +205,8 @@ fn main() { continue; } - let (name, def_id, ty) = composer - .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) - .unwrap(); + let (name, def_id, ty) = + composer.register_top_level(stmt, Some(resolver.clone()), "__main__".into()).unwrap(); internal_resolver.add_id_def(name, def_id); if let Some(ty) = ty { @@ -200,11 +214,7 @@ fn main() { } } - let signature = FunSignature { - args: vec![], - ret: primitive.int32, - vars: HashMap::new(), - }; + let signature = FunSignature { args: vec![], ret: primitive.int32, vars: HashMap::new() }; let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); @@ -216,17 +226,12 @@ fn main() { let instance = { let defs = top_level.definitions.read(); - let mut instance = - defs[resolver - .get_identifier_def("run".into()) - .unwrap_or_else(|_| panic!("cannot find run() entry point")).0 - ].write(); - if let TopLevelDef::Function { - instance_to_stmt, - instance_to_symbol, - .. - } = &mut *instance - { + let mut instance = defs[resolver + .get_identifier_def("run".into()) + .unwrap_or_else(|_| panic!("cannot find run() entry point")) + .0] + .write(); + if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance { instance_to_symbol.insert("".to_string(), "run".to_string()); instance_to_stmt[""].clone() } else { @@ -291,8 +296,7 @@ fn main() { passes.run_on(&main); let triple = TargetMachine::get_default_triple(); - let target = - Target::from_triple(&triple).expect("couldn't create target from target triple"); + let target = Target::from_triple(&triple).expect("couldn't create target from target triple"); let target_machine = target .create_target_machine( &triple, @@ -304,10 +308,6 @@ fn main() { ) .expect("couldn't create target machine"); target_machine - .write_to_file( - &main, - FileType::Object, - Path::new("module.o"), - ) + .write_to_file(&main, FileType::Object, Path::new("module.o")) .expect("couldn't write module to file"); } -- 2.44.1 From ede3706ca8f1c1606117b5a6507fc595157afba3 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 21 Feb 2022 18:41:42 +0800 Subject: [PATCH 4/5] type_inferencer: special case tuple index error message --- nac3core/src/typecheck/type_inferencer/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 9762fe0..4abcb9d 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1028,6 +1028,10 @@ impl<'a> Inferencer<'a> { Ok(ty) } _ => { + if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) + { + return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) + } // the index is not a constant, so value can only be a list self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; let list = self.unifier.add_ty(TypeEnum::TList { ty }); -- 2.44.1 From 3ad25c8f07d0e1859fab3573438d36cb72f7070b Mon Sep 17 00:00:00 2001 From: pca006132 Date: Tue, 22 Feb 2022 14:33:43 +0800 Subject: [PATCH 5/5] nac3core: sort error messages for determinism --- nac3core/src/codegen/mod.rs | 8 ++++---- nac3core/src/toplevel/composer.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index d83b22b..d403633 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -20,7 +20,7 @@ use inkwell::{ use itertools::Itertools; use nac3parser::ast::{Stmt, StrRef}; use parking_lot::{Condvar, Mutex}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -206,7 +206,7 @@ impl WorkerRegistry { let passes = PassManager::create(&module); pass_builder.populate_function_pass_manager(&passes); - let mut errors = Vec::new(); + let mut errors = HashSet::new(); while let Some(task) = self.receiver.recv().unwrap() { let tmp_module = context.create_module("tmp"); match gen_func(&context, generator, self, builder, tmp_module, task) { @@ -217,14 +217,14 @@ impl WorkerRegistry { } Err((old_builder, e)) => { builder = old_builder; - errors.push(e); + errors.insert(e); } } *self.task_count.lock() -= 1; self.wait_condvar.notify_all(); } if !errors.is_empty() { - panic!("Codegen error: {}", errors.iter().join("\n----------\n")); + panic!("Codegen error: {}", errors.into_iter().sorted().join("\n----------\n")); } let result = module.verify(); diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 4927507..f917e4b 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -464,7 +464,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } Ok(()) } @@ -560,7 +560,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } // second, get all ancestors @@ -591,7 +591,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } // insert the ancestors to the def list @@ -674,7 +674,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } // handle the inheritanced methods and fields @@ -727,7 +727,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } Ok(()) @@ -953,7 +953,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } Ok(()) } @@ -1555,7 +1555,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n---------\n")); + return Err(errors.into_iter().sorted().join("\n---------\n")); } for (i, signature, id) in constructors.into_iter() { @@ -1851,7 +1851,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors.iter().join("\n----------\n")); + return Err(errors.into_iter().sorted().join("\n----------\n")); } Ok(()) } -- 2.44.1