diff --git a/nac3ast/src/location.rs b/nac3ast/src/location.rs index a4416763..880b824a 100644 --- a/nac3ast/src/location.rs +++ b/nac3ast/src/location.rs @@ -3,7 +3,7 @@ use crate::ast_gen::StrRef; use std::fmt; #[derive(Clone, Copy, Debug, PartialEq)] -pub struct FileName(StrRef); +pub struct FileName(pub StrRef); impl Default for FileName { fn default() -> Self { FileName("unknown".into()) @@ -19,9 +19,9 @@ impl From for FileName { /// A location somewhere in the sourcecode. #[derive(Clone, Copy, Debug, Default, PartialEq)] pub struct Location { - row: usize, - column: usize, - file: FileName + pub row: usize, + pub column: usize, + pub file: FileName } impl fmt::Display for Location { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 5f7c6d12..31f00b9b 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -394,6 +394,48 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>( } } +pub fn final_proxy<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + target: BasicBlock<'ctx>, + block: BasicBlock<'ctx>, +) { + let (final_state, final_targets, final_paths) = ctx.outer_final.as_mut().unwrap(); + let prev = ctx.builder.get_insert_block().unwrap(); + ctx.builder.position_at_end(block); + unsafe { + ctx.builder.build_store(*final_state, target.get_address().unwrap()); + } + ctx.builder.position_at_end(prev); + final_targets.push(target); + final_paths.push(block); +} + +pub fn get_builtins<'ctx, 'a, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + symbol: &str, +) -> FunctionValue<'ctx> { + ctx.module.get_function(symbol).unwrap_or_else(|| { + let ty = match symbol { + "__artiq_raise" => ctx.ctx.void_type().fn_type( + &[ctx.get_llvm_type(generator, ctx.primitives.exception).into()], + false, + ), + "__artiq_resume" => ctx.ctx.void_type().fn_type(&[], false), + "__artiq_end_catch" => ctx.ctx.void_type().fn_type(&[], false), + _ => unimplemented!(), + }; + let fun = ctx.module.add_function(symbol, ty, None); + if symbol == "__artiq_raise" || symbol == "__artiq_resume" { + fun.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("noreturn"), 1), + ); + } + fun + }) +} + pub fn exn_constructor<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, obj: Option<(Type, ValueEnum<'ctx>)>, @@ -459,6 +501,352 @@ pub fn exn_constructor<'ctx, 'a>( Some(zelf.into()) } +pub fn gen_raise<'ctx, 'a, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + exception: Option<&BasicValueEnum<'ctx>>, + loc: Location, +) { + if let Some(exception) = exception { + unsafe { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let exception = exception.into_pointer_value(); + let file_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr"); + let filename = ctx.gen_string(generator, loc.file.0); + ctx.builder.build_store(file_ptr, filename); + let row_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr"); + ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)); + let col_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr"); + ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)); + + let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); + let name_ptr = ctx.builder.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr"); + ctx.builder.build_store(name_ptr, fun_name); + } + + let raise = get_builtins(generator, ctx, "__artiq_raise"); + let exception = *exception; + ctx.build_call_or_invoke(raise, &[exception], "raise"); + } else { + let resume = get_builtins(generator, ctx, "__artiq_resume"); + ctx.build_call_or_invoke(resume, &[], "resume"); + } + ctx.builder.build_unreachable(); +} + +pub fn gen_try<'ctx, 'a, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + target: &Stmt>, +) { + if let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node { + // if we need to generate anything related to exception, we must have personality defined + let personality_symbol = ctx.top_level.personality_symbol.as_ref().unwrap(); + let personality = ctx.module.get_function(personality_symbol).unwrap_or_else(|| { + let ty = ctx.ctx.i32_type().fn_type(&[], true); + ctx.module.add_function(personality_symbol, ty, None) + }); + let exception_type = ctx.get_llvm_type(generator, ctx.primitives.exception); + let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic); + let current_block = ctx.builder.get_insert_block().unwrap(); + let current_fun = current_block.get_parent().unwrap(); + let landingpad = ctx.ctx.append_basic_block(current_fun, "try.landingpad"); + let dispatcher = ctx.ctx.append_basic_block(current_fun, "try.dispatch"); + let mut dispatcher_end = dispatcher; + ctx.builder.position_at_end(dispatcher); + let exn = ctx.builder.build_phi(exception_type, "exn"); + ctx.builder.position_at_end(current_block); + + let mut cleanup = None; + let mut old_loop_target = None; + let mut old_return = None; + let mut old_outer_final = None; + let has_cleanup = if !finalbody.is_empty() { + let final_state = generator.gen_var_alloc(ctx, ptr_type.into()); + old_outer_final = ctx.outer_final.replace((final_state, Vec::new(), Vec::new())); + if let Some((continue_target, break_target)) = ctx.loop_target { + let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); + let continue_proxy = ctx.ctx.append_basic_block(current_fun, "try.continue"); + final_proxy(ctx, break_target, break_proxy); + final_proxy(ctx, continue_target, continue_proxy); + old_loop_target = ctx.loop_target.replace((continue_proxy, break_proxy)); + } + let return_proxy = ctx.ctx.append_basic_block(current_fun, "try.return"); + if let Some(return_target) = ctx.return_target { + final_proxy(ctx, return_target, return_proxy); + } else { + let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target"); + ctx.builder.position_at_end(return_target); + let return_value = ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret")); + ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)); + ctx.builder.position_at_end(current_block); + final_proxy(ctx, return_target, return_proxy); + } + old_return = ctx.return_target.replace(return_proxy); + cleanup = Some(ctx.ctx.append_basic_block(current_fun, "try.cleanup")); + true + } else { + ctx.outer_final.is_some() + }; + + let mut clauses = Vec::new(); + let mut found_catch_all = false; + for handler_node in handlers.iter() { + let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; + // none or Exception + if type_.is_none() || ctx.unifier.unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) { + clauses.push(None); + found_catch_all = true; + break; + } else { + let type_ = type_.as_ref().unwrap(); + let exn_name = ctx.resolver.get_type_name( + &ctx.top_level.definitions.read(), + &mut ctx.unifier, + type_.custom.unwrap(), + ); + let exn_id = ctx.resolver.get_string_id(&format!("0:{}", exn_name)); + let exn_id_global = + ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{}", exn_id)); + exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false)); + clauses.push(Some(exn_id_global.as_pointer_value().as_basic_value_enum())); + } + } + let mut all_clauses = clauses.clone(); + if let Some(old_clauses) = &ctx.outer_catch_clauses { + if !found_catch_all { + all_clauses.extend_from_slice(&old_clauses.0) + } + } + let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn)); + let old_unwind = ctx.unwind_target.replace(landingpad); + gen_block(generator, ctx, body.iter()); + if ctx.builder.get_insert_block().unwrap().get_terminator().is_none() { + gen_block(generator, ctx, orelse.iter()); + } + let body = ctx.builder.get_insert_block().unwrap(); + // reset old_clauses and old_unwind + let (all_clauses, _, _) = ctx.outer_catch_clauses.take().unwrap(); + ctx.outer_catch_clauses = old_clauses; + ctx.unwind_target = old_unwind; + ctx.return_target = old_return; + ctx.loop_target = old_loop_target; + old_loop_target = None; + + let old_unwind = if !finalbody.is_empty() { + let final_landingpad = ctx.ctx.append_basic_block(current_fun, "try.catch.final"); + ctx.builder.position_at_end(final_landingpad); + ctx.builder.build_landing_pad( + ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), + personality, + &[], + true, + "try.catch.final", + ); + ctx.builder.build_unconditional_branch(cleanup.unwrap()); + ctx.builder.position_at_end(body); + ctx.unwind_target.replace(final_landingpad) + } else { + None + }; + + // run end_catch before continue/break/return + let mut final_proxy_lambda = + |ctx: &mut CodeGenContext<'ctx, 'a>, + target: BasicBlock<'ctx>, + block: BasicBlock<'ctx>| final_proxy(ctx, target, block); + let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>, + target: BasicBlock<'ctx>, + block: BasicBlock<'ctx>| { + ctx.builder.position_at_end(block); + ctx.builder.build_unconditional_branch(target); + ctx.builder.position_at_end(body); + }; + let redirect = if ctx.outer_final.is_some() { + &mut final_proxy_lambda + as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) + } else { + &mut redirect_lambda + as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) + }; + let resume = get_builtins(generator, ctx, "__artiq_resume"); + let end_catch = get_builtins(generator, ctx, "__artiq_end_catch"); + if let Some((continue_target, break_target)) = ctx.loop_target.take() { + let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); + let continue_proxy = ctx.ctx.append_basic_block(current_fun, "try.continue"); + ctx.builder.position_at_end(break_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + ctx.builder.position_at_end(continue_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + ctx.builder.position_at_end(body); + redirect(ctx, break_target, break_proxy); + redirect(ctx, continue_target, continue_proxy); + ctx.loop_target = Some((continue_proxy, break_proxy)); + old_loop_target = Some((continue_target, break_target)); + } + let return_proxy = ctx.ctx.append_basic_block(current_fun, "try.return"); + ctx.builder.position_at_end(return_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + let return_target = ctx.return_target.take().unwrap_or_else(|| { + let doreturn = ctx.ctx.append_basic_block(current_fun, "try.doreturn"); + ctx.builder.position_at_end(doreturn); + let return_value = ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret")); + ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)); + doreturn + }); + redirect(ctx, return_target, return_proxy); + ctx.return_target = Some(return_proxy); + old_return = Some(return_target); + + let mut post_handlers = Vec::new(); + + let exnid = if !handlers.is_empty() { + ctx.builder.position_at_end(dispatcher); + unsafe { + let zero = ctx.ctx.i32_type().const_zero(); + let exnid_ptr = ctx.builder.build_gep( + exn.as_basic_value().into_pointer_value(), + &[zero, zero], + "exnidptr", + ); + Some(ctx.builder.build_load(exnid_ptr, "exnid")) + } + } else { + None + }; + + for (handler_node, exn_type) in handlers.iter().zip(clauses.iter()) { + let ExcepthandlerKind::ExceptHandler { type_, name, body } = &handler_node.node; + let handler_bb = ctx.ctx.append_basic_block(current_fun, "try.handler"); + ctx.builder.position_at_end(handler_bb); + if let Some(name) = name { + let exn_ty = ctx.get_llvm_type(generator, type_.as_ref().unwrap().custom.unwrap()); + let exn_store = generator.gen_var_alloc(ctx, exn_ty); + ctx.var_assignment.insert(*name, (exn_store, None, 0)); + ctx.builder.build_store(exn_store, exn.as_basic_value()); + } + gen_block(generator, ctx, body.iter()); + let current = ctx.builder.get_insert_block().unwrap(); + // only need to call end catch if not terminated + // otherwise, we already handled in return/break/continue/raise + if current.get_terminator().is_none() { + ctx.builder.build_call(end_catch, &[], "end_catch"); + } + post_handlers.push(current); + ctx.builder.position_at_end(dispatcher_end); + if let Some(exn_type) = exn_type { + let dispatcher_cont = + ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont"); + let actual_id = exnid.unwrap().into_int_value(); + let expected_id = ctx + .builder + .build_load(exn_type.into_pointer_value(), "expected_id") + .into_int_value(); + let result = ctx.builder.build_int_compare(EQ, actual_id, expected_id, "exncheck"); + ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont); + dispatcher_end = dispatcher_cont; + } else { + ctx.builder.build_unconditional_branch(handler_bb); + break; + } + } + + ctx.unwind_target = old_unwind; + ctx.loop_target = old_loop_target; + ctx.return_target = old_return; + + ctx.builder.position_at_end(landingpad); + let clauses: Vec<_> = if finalbody.is_empty() { &all_clauses } else { &clauses } + .iter() + .map(|v| v.unwrap_or(ptr_type.const_zero().into())) + .collect(); + let landingpad_value = ctx + .builder + .build_landing_pad( + ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), + personality, + &clauses, + has_cleanup, + "try.landingpad", + ) + .into_struct_value(); + let exn_val = ctx.builder.build_extract_value(landingpad_value, 1, "exn").unwrap(); + ctx.builder.build_unconditional_branch(dispatcher); + exn.add_incoming(&[(&exn_val, landingpad)]); + + if dispatcher_end.get_terminator().is_none() { + ctx.builder.position_at_end(dispatcher_end); + if let Some(cleanup) = cleanup { + ctx.builder.build_unconditional_branch(cleanup); + } else if let Some((_, outer_dispatcher, phi)) = ctx.outer_catch_clauses { + phi.add_incoming(&[(&exn_val, dispatcher_end)]); + ctx.builder.build_unconditional_branch(outer_dispatcher); + } else { + ctx.build_call_or_invoke(resume, &[], "resume"); + ctx.builder.build_unreachable(); + } + } + + if finalbody.is_empty() { + let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); + if body.get_terminator().is_none() { + ctx.builder.position_at_end(body); + ctx.builder.build_unconditional_branch(tail); + } + if matches!(cleanup, Some(cleanup) if cleanup.get_terminator().is_none()) { + ctx.builder.position_at_end(cleanup.unwrap()); + ctx.builder.build_unconditional_branch(tail); + } + for post_handler in post_handlers { + if post_handler.get_terminator().is_none() { + ctx.builder.position_at_end(post_handler); + ctx.builder.build_unconditional_branch(tail); + } + } + ctx.builder.position_at_end(tail); + } else { + let final_branches = ctx.outer_final.take().unwrap(); + ctx.outer_final = old_outer_final; + + // exception path + let cleanup = cleanup.unwrap(); + ctx.builder.position_at_end(cleanup); + gen_block(generator, ctx, finalbody.iter()); + if !ctx.is_terminated() { + ctx.build_call_or_invoke(resume, &[], "resume"); + ctx.builder.build_unreachable(); + } + + // normal path + let (final_state, mut final_targets, final_paths) = final_branches; + let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); + final_targets.push(tail); + let finalizer = ctx.ctx.append_basic_block(current_fun, "try.finally"); + ctx.builder.position_at_end(finalizer); + gen_block(generator, ctx, finalbody.iter()); + if !ctx.is_terminated() { + let dest = ctx.builder.build_load(final_state, "final_dest"); + ctx.builder.build_indirect_branch(dest, &final_targets); + } + for block in final_paths.iter() { + if block.get_terminator().is_none() { + ctx.builder.position_at_end(*block); + ctx.builder.build_unconditional_branch(finalizer); + } + } + for block in [body].iter().chain(post_handlers.iter()) { + if block.get_terminator().is_none() { + ctx.builder.position_at_end(*block); + unsafe { + ctx.builder.build_store(final_state, tail.get_address().unwrap()); + } + ctx.builder.build_unconditional_branch(finalizer); + } + } + ctx.builder.position_at_end(tail); + } } else { unreachable!() } @@ -531,9 +919,15 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( let value = gen_binop_expr(generator, ctx, target, op, value); generator.gen_assign(ctx, target, value); } + StmtKind::Try { .. } => gen_try(generator, ctx, stmt), + StmtKind::Raise { exc, .. } => { + let exc = exc.as_ref().map(|exc| generator.gen_expr(ctx, exc).unwrap().to_basic_value_enum(ctx, generator)); + gen_raise(generator, ctx, exc.as_ref(), stmt.location) + } _ => unimplemented!(), }; - false +} + pub fn gen_block<'ctx, 'a, 'b, G: CodeGenerator, I: Iterator>>>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index dcd3c378..6509cc40 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -109,6 +109,7 @@ pub trait SymbolResolver { fn get_symbol_location(&self, str: StrRef) -> Option; fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; + fn get_string_id(&self, s: &str) -> i32; // handle function call etc. } @@ -297,6 +298,25 @@ impl dyn SymbolResolver + Send + Sync { ) -> Result { parse_type_annotation(self, top_level_defs, unifier, primitives, expr) } + + pub fn get_type_name( + &self, + top_level_defs: &[Arc>], + unifier: &mut Unifier, + ty: Type, + ) -> String { + unifier.stringify( + ty, + &mut |id| { + if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() { + name.to_string() + } else { + unreachable!("expected class definition") + } + }, + &mut |id| format!("var{}", id), + ) + } } impl Debug for dyn SymbolResolver + Send + Sync { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 762aaed0..c3adddd4 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -148,7 +148,7 @@ impl TopLevelComposer { self.unifier.get_shared_unifier(), self.primitives_ty, )])), - personality_symbol: None, + personality_symbol: Some("__artiq_personality".into()), } } @@ -166,7 +166,7 @@ impl TopLevelComposer { ) -> Result<(StrRef, DefinitionId, Option), String> { let defined_names = &mut self.defined_names; match &ast.node { - ast::StmtKind::ClassDef { name: class_name, body, .. } => { + ast::StmtKind::ClassDef { name: class_name, bases, body, .. } => { if self.keyword_list.contains(class_name) { return Err(format!( "cannot use keyword `{}` as a class name (at {})", @@ -174,11 +174,12 @@ impl TopLevelComposer { ast.location )); } - if !defined_names.insert({ - let mut n = mod_path.clone(); - n.push_str(&class_name.to_string()); - n - }) { + let fully_qualified_class_name = if mod_path.is_empty() { + *class_name + } else { + format!("{}.{}", &mod_path, class_name).into() + }; + if !defined_names.insert(fully_qualified_class_name.into()) { return Err(format!( "duplicate definition of class `{}` (at {})", class_name, @@ -196,7 +197,7 @@ impl TopLevelComposer { Arc::new(RwLock::new(Self::make_top_level_class_def( class_def_id, resolver.clone(), - class_name, + fully_qualified_class_name, Some(constructor_ty), ))), None, @@ -218,8 +219,13 @@ impl TopLevelComposer { // we do not push anything to the def list, so we keep track of the index // and then push in the correct order after the for loop let mut class_method_index_offset = 0; - let mut contains_constructor = false; let init_id = "__init__".into(); + let exception_id = "Exception".into(); + // TODO: Fix this hack. We will generate constructor for classes that inherit + // from Exception class (directly or indirectly), but this code cannot handle + // subclass of other exception classes. + let mut contains_constructor = bases + .iter().any(|base| matches!(base.node, ast::ExprKind::Name { id, .. } if id == exception_id)); for b in body { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { if method_name == &init_id { @@ -232,21 +238,14 @@ impl TopLevelComposer { b.location )); } - let global_class_method_name = { - let mut n = mod_path.clone(); - n.push_str( - Self::make_class_method_name( - class_name.into(), - &method_name.to_string(), - ) - .as_str(), - ); - n - }; + let global_class_method_name = Self::make_class_method_name( + fully_qualified_class_name.into(), + &method_name.to_string(), + ); if !defined_names.insert(global_class_method_name.clone()) { return Err(format!( "class method `{}` defined twice (at {})", - &global_class_method_name[mod_path.len()..], + global_class_method_name, b.location )); } @@ -304,15 +303,15 @@ impl TopLevelComposer { // if self.keyword_list.contains(name) { // return Err("cannot use keyword as a top level function name".into()); // } - let global_fun_name = { - let mut n = mod_path.clone(); - n.push_str(&name.to_string()); - n + let global_fun_name = if mod_path.is_empty() { + name.to_string() + } else { + format!("{}.{}", mod_path, name) }; if !defined_names.insert(global_fun_name.clone()) { return Err(format!( "top level function `{}` defined twice (at {})", - &global_fun_name[mod_path.len()..], + global_fun_name, ast.location )); } @@ -1582,6 +1581,7 @@ impl TopLevelComposer { primitives: &self.primitives_ty, virtual_checks: &mut Vec::new(), calls: &mut calls, + in_handler: false }; let fun_body = @@ -1595,6 +1595,13 @@ impl TopLevelComposer { instance_to_symbol.insert("".into(), simple_name.to_string()); continue; } + if !decorator_list.is_empty() + && matches!(&decorator_list[0].node, + ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) + { + instance_to_symbol.insert("".into(), simple_name.to_string()); + continue; + } body } else { unreachable!("must be function def ast") @@ -1605,7 +1612,42 @@ impl TopLevelComposer { let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; - + { + // check virtuals + let defs = ctx.definitions.read(); + for (subtype, base, loc) in inferencer.virtual_checks.iter() { + let base_id = { + let base = inferencer.unifier.get_ty(*base); + if let TypeEnum::TObj { obj_id, .. } = &*base { + *obj_id + } else { + return Err(format!("Base type should be a class (at {})", loc)) + } + }; + let subtype_id = { + let ty = inferencer.unifier.get_ty(*subtype); + if let TypeEnum::TObj { obj_id, .. } = &*ty { + *obj_id + } else { + let base_repr = inferencer.unifier.default_stringify(*base); + let subtype_repr = inferencer.unifier.default_stringify(*subtype); + return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) + } + }; + let subtype_entry = defs[subtype_id.0].read(); + if let TopLevelDef::Class { ancestors, .. } = &*subtype_entry { + let m = ancestors.iter() + .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); + if m.is_none() { + let base_repr = inferencer.unifier.default_stringify(*base); + let subtype_repr = inferencer.unifier.default_stringify(*subtype); + return Err(format!("Expected a subtype of {}, but got {} (at {})", base_repr, subtype_repr, loc)) + } + } else { + unreachable!(); + } + } + } if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned { let def_ast_list = &self.definition_ast_list; let ret_str = self.unifier.stringify( diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 142b5ef8..26fa0c40 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -227,6 +227,20 @@ impl<'a> Inferencer<'a> { self.check_block(body, &mut new_defined_identifiers)?; Ok(false) } + StmtKind::Try { body, handlers, orelse, finalbody, .. } => { + self.check_block(body, &mut defined_identifiers.clone())?; + self.check_block(orelse, &mut defined_identifiers.clone())?; + for handler in handlers.iter() { + let mut defined_identifiers = defined_identifiers.clone(); + let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; + if let Some(name) = name { + defined_identifiers.insert(*name); + } + self.check_block(body, &mut defined_identifiers)?; + } + self.check_block(finalbody, defined_identifiers)?; + Ok(false) + } StmtKind::Expr { value, .. } => { self.check_expr(value, defined_identifiers)?; Ok(false) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 87bf8235..2634fa76 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -52,9 +52,10 @@ pub struct Inferencer<'a> { pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, - pub virtual_checks: &'a mut Vec<(Type, Type)>, + pub virtual_checks: &'a mut Vec<(Type, Type, Location)>, pub variable_mapping: HashMap, pub calls: &'a mut HashMap, + pub in_handler: bool, } struct NaiveFolder(); @@ -123,6 +124,56 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { }, } } + ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment } => { + let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + let outer_in_handler = self.in_handler; + let mut exception_handlers = Vec::with_capacity(handlers.len()); + self.in_handler = true; + { + let top_level_defs = self.top_level.definitions.read(); + let mut naive_folder = NaiveFolder(); + for handler in handlers.into_iter() { + let ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = handler.node; + let type_ = if let Some(type_) = type_ { + let typ = self.function_data.resolver.parse_type_annotation( + top_level_defs.as_slice(), + self.unifier, + self.primitives, + &type_ + )?; + self.virtual_checks.push((typ, self.primitives.exception, handler.location)); + if let Some(name) = name { + if !self.defined_identifiers.contains(&name) { + self.defined_identifiers.insert(name); + } + if let Some(old_typ) = self.variable_mapping.insert(name, typ) { + self.unifier.unify(old_typ, typ)?; + } + } + let mut type_ = naive_folder.fold_expr(*type_)?; + type_.custom = Some(typ); + Some(Box::new(type_)) + } else { + None + }; + let body = body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + exception_handlers.push(Located { + location: handler.location, + node: ast::ExcepthandlerKind::ExceptHandler { type_, name, body }, + custom: None + }); + } + } + self.in_handler = outer_in_handler; + let handlers = exception_handlers; + let orelse = orelse.into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + let finalbody = finalbody .into_iter().map(|stmt| self.fold_stmt(stmt)).collect::, _>>()?; + Located { + location: node.location, + node: ast::StmtKind::Try { body, handlers, orelse, finalbody, config_comment }, + custom: None + } + } ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => { self.infer_pattern(&target)?; let target = self.fold_expr(*target)?; @@ -229,7 +280,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { _ => fold::fold_stmt(self, node)?, }; match &stmt.node { - ast::StmtKind::For { .. } => {} + ast::StmtKind::For { .. } => {}, + ast::StmtKind::Try { .. } => {}, ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; } @@ -242,6 +294,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Pass { .. } => {} + ast::StmtKind::Raise { exc, cause, .. } => { + if let Some(cause) = cause { + return report_error("raise ... from cause is not supported", cause.location); + } + if let Some(exc) = exc { + self.virtual_checks.push((exc.custom.unwrap(), self.primitives.exception, exc.location)); + } else if !self.in_handler { + return report_error("cannot reraise outside exception handlers", stmt.location); + } + } ast::StmtKind::With { items, .. } => { for item in items.iter() { let ty = item.context_expr.custom.unwrap(); @@ -537,6 +599,8 @@ impl<'a> Inferencer<'a> { top_level: self.top_level, defined_identifiers, variable_mapping, + // lambda should not be considered in exception handler + in_handler: false, }; let fun = FunSignature { args: fn_args @@ -583,6 +647,8 @@ impl<'a> Inferencer<'a> { primitives: self.primitives, calls: self.calls, defined_identifiers, + // listcomp expr should not be considered as inside an exception handler... + in_handler: false }; let generator = generators.pop().unwrap(); if generator.is_async { @@ -661,7 +727,7 @@ impl<'a> Inferencer<'a> { } else { self.unifier.get_fresh_var().0 }; - self.virtual_checks.push((arg0.custom.unwrap(), ty)); + self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location)); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); return Ok(Located { location,