From a19f1065e3d1d3852abdbe76070b0b21cc4ead9d Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 12 Dec 2023 13:38:27 +0800 Subject: [PATCH] meta: Refactor to use more let-else bindings --- nac3artiq/src/codegen.rs | 256 ++-- nac3artiq/src/lib.rs | 13 +- nac3artiq/src/symbol_resolver.rs | 223 ++- nac3artiq/src/timeline.rs | 362 ++--- nac3core/src/codegen/expr.rs | 760 +++++----- nac3core/src/codegen/mod.rs | 89 +- nac3core/src/codegen/stmt.rs | 1213 +++++++-------- nac3core/src/symbol_resolver.rs | 8 +- nac3core/src/toplevel/builtins.rs | 26 +- nac3core/src/toplevel/composer.rs | 1326 ++++++++--------- nac3core/src/toplevel/helper.rs | 81 +- nac3core/src/toplevel/type_annotation.rs | 27 +- nac3core/src/typecheck/type_inferencer/mod.rs | 66 +- nac3core/src/typecheck/typedef/mod.rs | 20 +- nac3core/src/typecheck/typedef/test.rs | 18 +- nac3standalone/src/main.rs | 9 +- 16 files changed, 2227 insertions(+), 2270 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index a5f42aa..142fe51 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -215,148 +215,148 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - if let StmtKind::With { items, body, .. } = &stmt.node { - if items.len() == 1 && items[0].optional_vars.is_none() { - let item = &items[0]; + let StmtKind::With { items, body, .. } = &stmt.node else { + unreachable!() + }; - // Behavior of parallel and sequential: - // Each function call (indirectly, can be inside a sequential block) within a parallel - // block will update the end variable to the maximum now_mu in the block. - // Each function call directly inside a parallel block will reset the timeline after - // execution. A parallel block within a sequential block (or not within any block) will - // set the timeline to the max now_mu within the block (and the outer max now_mu will also - // be updated). - // - // Implementation: We track the start and end separately. - // - If there is a start variable, it indicates that we are directly inside a - // parallel block and we have to reset the timeline after every function call. - // - If there is a end variable, it indicates that we are (indirectly) inside a - // parallel block, and we should update the max end value. - if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { - if id == &"parallel".into() || id == &"legacy_parallel".into() { - let old_start = self.start.take(); - let old_end = self.end.take(); - let old_parallel_mode = self.parallel_mode; + if items.len() == 1 && items[0].optional_vars.is_none() { + let item = &items[0]; - let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start)? - .unwrap() - .to_basic_value_enum(ctx, self, old_start.custom.unwrap())? - } else { - self.timeline.emit_now_mu(ctx) - }; + // Behavior of parallel and sequential: + // Each function call (indirectly, can be inside a sequential block) within a parallel + // block will update the end variable to the maximum now_mu in the block. + // Each function call directly inside a parallel block will reset the timeline after + // execution. A parallel block within a sequential block (or not within any block) will + // set the timeline to the max now_mu within the block (and the outer max now_mu will also + // be updated). + // + // Implementation: We track the start and end separately. + // - If there is a start variable, it indicates that we are directly inside a + // parallel block and we have to reset the timeline after every function call. + // - If there is a end variable, it indicates that we are (indirectly) inside a + // parallel block, and we should update the max end value. + if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { + if id == &"parallel".into() || id == &"legacy_parallel".into() { + let old_start = self.start.take(); + let old_end = self.end.take(); + let old_parallel_mode = self.parallel_mode; - // Emulate variable allocation, as we need to use the CodeGenContext - // HashMap to store our variable due to lifetime limitation - // Note: we should be able to store variables directly if generic - // associative type is used by limiting the lifetime of CodeGenerator to - // the LLVM Context. - // The name is guaranteed to be unique as users cannot use this as variable - // name. - self.start = old_start.clone().map_or_else( - || { - let start = format!("with-{}-start", self.name_counter).into(); - let start_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, - custom: Some(ctx.primitives.int64), - }; - let start = self - .gen_store_target(ctx, &start_expr, Some("start.addr"))? - .unwrap(); - ctx.builder.build_store(start, now); - Ok(Some(start_expr)) as Result<_, String> - }, - |v| Ok(Some(v)), - )?; - let end = format!("with-{}-end", self.name_counter).into(); - let end_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, - custom: Some(ctx.primitives.int64), - }; - let end = self - .gen_store_target(ctx, &end_expr, Some("end.addr"))? - .unwrap(); - ctx.builder.build_store(end, now); - self.end = Some(end_expr); - self.name_counter += 1; - self.parallel_mode = match id.to_string().as_str() { - "parallel" => ParallelMode::Deep, - "legacy_parallel" => ParallelMode::Legacy, - _ => unreachable!(), - }; - - self.gen_block(ctx, body.iter())?; - - let current = ctx.builder.get_insert_block().unwrap(); - - // if the current block is terminated, move before the terminator - // we want to set the timeline before reaching the terminator - // TODO: This may be unsound if there are multiple exit paths in the - // block... e.g. - // if ...: - // return - // Perhaps we can fix this by using actual with block? - let reset_position = if let Some(terminator) = current.get_terminator() { - ctx.builder.position_before(&terminator); - true - } else { - false - }; - - // set duration - let end_expr = self.end.take().unwrap(); - let end_val = self - .gen_expr(ctx, &end_expr)? + let now = if let Some(old_start) = &old_start { + self.gen_expr(ctx, old_start)? .unwrap() - .to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?; + .to_basic_value_enum(ctx, self, old_start.custom.unwrap())? + } else { + self.timeline.emit_now_mu(ctx) + }; - // inside a sequential block - if old_start.is_none() { - self.timeline.emit_at_mu(ctx, end_val); - } + // Emulate variable allocation, as we need to use the CodeGenContext + // HashMap to store our variable due to lifetime limitation + // Note: we should be able to store variables directly if generic + // associative type is used by limiting the lifetime of CodeGenerator to + // the LLVM Context. + // The name is guaranteed to be unique as users cannot use this as variable + // name. + self.start = old_start.clone().map_or_else( + || { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, + custom: Some(ctx.primitives.int64), + }; + let start = self + .gen_store_target(ctx, &start_expr, Some("start.addr"))? + .unwrap(); + ctx.builder.build_store(start, now); + Ok(Some(start_expr)) as Result<_, String> + }, + |v| Ok(Some(v)), + )?; + let end = format!("with-{}-end", self.name_counter).into(); + let end_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, + custom: Some(ctx.primitives.int64), + }; + let end = self + .gen_store_target(ctx, &end_expr, Some("end.addr"))? + .unwrap(); + ctx.builder.build_store(end, now); + self.end = Some(end_expr); + self.name_counter += 1; + self.parallel_mode = match id.to_string().as_str() { + "parallel" => ParallelMode::Deep, + "legacy_parallel" => ParallelMode::Legacy, + _ => unreachable!(), + }; - // inside a parallel block, should update the outer max now_mu - self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; + self.gen_block(ctx, body.iter())?; - self.parallel_mode = old_parallel_mode; - self.end = old_end; - self.start = old_start; + let current = ctx.builder.get_insert_block().unwrap(); - if reset_position { - ctx.builder.position_at_end(current); - } + // if the current block is terminated, move before the terminator + // we want to set the timeline before reaching the terminator + // TODO: This may be unsound if there are multiple exit paths in the + // block... e.g. + // if ...: + // return + // Perhaps we can fix this by using actual with block? + let reset_position = if let Some(terminator) = current.get_terminator() { + ctx.builder.position_before(&terminator); + true + } else { + false + }; - return Ok(()); - } else if id == &"sequential".into() { - // For deep parallel, temporarily take away start to avoid function calls in - // the block from resetting the timeline. - // This does not affect legacy parallel, as the timeline will be reset after - // this block finishes execution. - let start = self.start.take(); - self.gen_block(ctx, body.iter())?; - self.start = start; + // set duration + let end_expr = self.end.take().unwrap(); + let end_val = self + .gen_expr(ctx, &end_expr)? + .unwrap() + .to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?; - // Reset the timeline when we are exiting the sequential block - // Legacy parallel does not need this, since it will be reset after codegen - // for this statement is completed - if self.parallel_mode == ParallelMode::Deep { - self.timeline_reset_start(ctx)?; - } - - return Ok(()); + // inside a sequential block + if old_start.is_none() { + self.timeline.emit_at_mu(ctx, end_val); } + + // inside a parallel block, should update the outer max now_mu + self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; + + self.parallel_mode = old_parallel_mode; + self.end = old_end; + self.start = old_start; + + if reset_position { + ctx.builder.position_at_end(current); + } + + return Ok(()); + } else if id == &"sequential".into() { + // For deep parallel, temporarily take away start to avoid function calls in + // the block from resetting the timeline. + // This does not affect legacy parallel, as the timeline will be reset after + // this block finishes execution. + let start = self.start.take(); + self.gen_block(ctx, body.iter())?; + self.start = start; + + // Reset the timeline when we are exiting the sequential block + // Legacy parallel does not need this, since it will be reset after codegen + // for this statement is completed + if self.parallel_mode == ParallelMode::Deep { + self.timeline_reset_start(ctx)?; + } + + return Ok(()); } } - - // not parallel/sequential - gen_with(self, ctx, stmt) - } else { - unreachable!() } + + // not parallel/sequential + gen_with(self, ctx, stmt) } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 1d63dae..3e14684 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -533,14 +533,13 @@ impl Nac3 { let instance = { let defs = top_level.definitions.read(); let mut definition = defs[def_id.0].write(); - if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = - &mut *definition - { - instance_to_symbol.insert(String::new(), "__modinit__".into()); - instance_to_stmt[""].clone() - } else { + let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = + &mut *definition else { unreachable!() - } + }; + + instance_to_symbol.insert(String::new(), "__modinit__".into()); + instance_to_stmt[""].clone() }; let task = CodeGenTask { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 4b1c2e4..18e5fa6 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -311,37 +311,37 @@ impl InnerResolver { unreachable!("none cannot be typeid") } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { let def = defs[def_id.0].read(); - if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { - // do not handle type var param and concrete check here, and no subst - Ok(Ok({ - let ty = TypeEnum::TObj { - obj_id: *object_id, - params: type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { - (*id, *x) - } else { - unreachable!() - } - }) - .collect(), - fields: { - let mut res = methods - .iter() - .map(|(iden, ty, _)| (*iden, (*ty, false))) - .collect::>(); - res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); - res - }, - }; - // here also false, later instantiation use python object to check compatible - (unifier.add_ty(ty), false) - })) - } else { + let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else { // only object is supported, functions are not supported unreachable!("function type is not supported, should not be queried") - } + }; + + // do not handle type var param and concrete check here, and no subst + Ok(Ok({ + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: type_vars + .iter() + .map(|x| { + let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) else { + unreachable!() + }; + + (*id, *x) + }) + .collect(), + fields: { + let mut res = methods + .iter() + .map(|(iden, ty, _)| (*iden, (*ty, false))) + .collect::>(); + res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); + res + }, + }; + // here also false, later instantiation use python object to check compatible + (unifier.add_ty(ty), false) + })) } else if ty_ty_id == self.primitive_ids.typevar { let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap(); let (constraint_types, is_const_generic) = { @@ -652,23 +652,23 @@ impl InnerResolver { // if is `none` let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; if zelf_id == self.primitive_ids.none { - if let TypeEnum::TObj { params, .. } = - unifier.get_ty_immutable(primitives.option).as_ref() - { - let var_map = params - .iter() - .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) { - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) - } else { - unreachable!() - } - }) - .collect::>(); - return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) - } - unreachable!("must be tobj") + let ty_enum = unifier.get_ty_immutable(primitives.option); + let TypeEnum::TObj { params, .. } = ty_enum.as_ref() else { + unreachable!("must be tobj") + }; + + let var_map = params + .iter() + .map(|(id_var, ty)| { + let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) else { + unreachable!() + }; + + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + }) + .collect::>(); + return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) } let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { @@ -688,14 +688,13 @@ impl InnerResolver { let var_map = params .iter() .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, name, loc, .. } = - &*unifier.get_ty(*ty) - { - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) - } else { + let TypeEnum::TVar { id, range, name, loc, .. } = + &*unifier.get_ty(*ty) else { unreachable!() - } + }; + + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) }) .collect::>(); let mut instantiate_obj = || { @@ -900,28 +899,29 @@ impl InnerResolver { Ok(Some(global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { - if let TypeEnum::TTuple { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() { - let tup_tys = ty.iter(); - let elements: &PyTuple = obj.downcast()?; - assert_eq!(elements.len(), tup_tys.len()); - let val: Result>, _> = - elements - .iter() - .enumerate() - .zip(tup_tys) - .map(|((i, elem), ty)| self - .get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| - super::CompileError::new_err( - format!("Error getting element {i}: {e}") - ) + let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); + let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { + unreachable!() + }; + + let tup_tys = ty.iter(); + let elements: &PyTuple = obj.downcast()?; + assert_eq!(elements.len(), tup_tys.len()); + let val: Result>, _> = + elements + .iter() + .enumerate() + .zip(tup_tys) + .map(|((i, elem), ty)| self + .get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| + super::CompileError::new_err( + format!("Error getting element {i}: {e}") ) - ).collect(); - let val = val?.unwrap(); - let val = ctx.ctx.const_struct(&val, false); - Ok(Some(val.into())) - } else { - unreachable!("must expect tuple type") - } + ) + ).collect(); + let val = val?.unwrap(); + let val = ctx.ctx.const_struct(&val, false); + Ok(Some(val.into())) } else if ty_id == self.primitive_ids.option { let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { TypeEnum::TObj { obj_id, params, .. } @@ -993,27 +993,25 @@ impl InnerResolver { // should be classes let definition = top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read(); - if let TopLevelDef::Class { fields, .. } = &*definition { - let values: Result>, _> = fields - .iter() - .map(|(name, ty, _)| { - self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) - .map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}"))) - }) - .collect(); - let values = values?; - if let Some(values) = values { - let val = ty.const_named_struct(&values); - let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) - }); - global.set_initializer(&val); - Ok(Some(global.as_pointer_value().into())) - } else { - Ok(None) - } + let TopLevelDef::Class { fields, .. } = &*definition else { unreachable!() }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty, _)| { + self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) + .map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}"))) + }) + .collect(); + let values = values?; + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) } else { - unreachable!() + Ok(None) } } } @@ -1065,27 +1063,26 @@ impl InnerResolver { impl SymbolResolver for Resolver { fn get_default_param_value(&self, expr: &ast::Expr) -> Option { - match &expr.node { - ast::ExprKind::Name { id, .. } => { - Python::with_gil(|py| -> PyResult> { - let obj: &PyAny = self.0.module.extract(py)?; - let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); - let mut sym_value = None; - for (key, val) in members { - let key: &str = key.extract()?; - if key == id.to_string() { - if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val) { - sym_value = Some(v); - } - break; - } + let ast::ExprKind::Name { id, .. } = &expr.node else { + + unreachable!("only for resolving names") + }; + + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.0.module.extract(py)?; + let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); + let mut sym_value = None; + for (key, val) in members { + let key: &str = key.extract()?; + if key == id.to_string() { + if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val) { + sym_value = Some(v); } - Ok(sym_value) - }) - .unwrap() + break; + } } - _ => unreachable!("only for resolving names"), - } + Ok(sym_value) + }).unwrap() } fn get_symbol_type( diff --git a/nac3artiq/src/timeline.rs b/nac3artiq/src/timeline.rs index de45338..41bf185 100644 --- a/nac3artiq/src/timeline.rs +++ b/nac3artiq/src/timeline.rs @@ -29,29 +29,29 @@ impl TimeFns for NowPinningTimeFns64 { let now_hiptr = ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { - ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }; + let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { + unreachable!() + }; - if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( - ctx.builder.build_load(now_hiptr, "now.hi"), - ctx.builder.build_load(now_loptr, "now.lo"), - ) { - let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); - let shifted_hi = ctx.builder.build_left_shift( - zext_hi, - i64_type.const_int(32, false), - "", - ); - let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, ""); - ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into() - } else { - unreachable!(); - } - } else { - unreachable!(); - } + let now_loptr = unsafe { + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") + }; + + let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( + ctx.builder.build_load(now_hiptr, "now.hi"), + ctx.builder.build_load(now_loptr, "now.lo"), + ) else { + unreachable!() + }; + + let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); + let shifted_hi = ctx.builder.build_left_shift( + zext_hi, + i64_type.const_int(32, false), + "", + ); + let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, ""); + ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into() } fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { @@ -59,41 +59,41 @@ impl TimeFns for NowPinningTimeFns64 { let i64_type = ctx.ctx.i64_type(); let i64_32 = i64_type.const_int(32, false); - if let BasicValueEnum::IntValue(time) = t { - let time_hi = ctx.builder.build_int_truncate( - ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), - i32_type, - "", - ); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); - let now = ctx - .module - .get_global("now") - .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder.build_bitcast( - now, - i32_type.ptr_type(AddressSpace::default()), - "now.hi.addr", - ); + let BasicValueEnum::IntValue(time) = t else { + unreachable!() + }; - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { - ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }; - ctx.builder - .build_store(now_hiptr, time_hi) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - ctx.builder - .build_store(now_loptr, time_lo) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - } else { - unreachable!(); - } - } else { - unreachable!(); - } + let time_hi = ctx.builder.build_int_truncate( + ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), + i32_type, + "", + ); + let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); + let now = ctx + .module + .get_global("now") + .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); + let now_hiptr = ctx.builder.build_bitcast( + now, + i32_type.ptr_type(AddressSpace::default()), + "now.hi.addr", + ); + + let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { + unreachable!() + }; + + let now_loptr = unsafe { + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") + }; + ctx.builder + .build_store(now_hiptr, time_hi) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); + ctx.builder + .build_store(now_loptr, time_lo) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); } fn emit_delay_mu<'ctx>( @@ -110,56 +110,56 @@ impl TimeFns for NowPinningTimeFns64 { let now_hiptr = ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { - ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }; - - if let ( - BasicValueEnum::IntValue(now_hi), - BasicValueEnum::IntValue(now_lo), - BasicValueEnum::IntValue(dt), - ) = ( - ctx.builder.build_load(now_hiptr, "now.hi"), - ctx.builder.build_load(now_loptr, "now.lo"), - dt, - ) { - let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); - let shifted_hi = ctx.builder.build_left_shift( - zext_hi, - i64_type.const_int(32, false), - "", - ); - let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, ""); - let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now"); - - let time = ctx.builder.build_int_add(now_val, dt, "time"); - let time_hi = ctx.builder.build_int_truncate( - ctx.builder.build_right_shift( - time, - i64_type.const_int(32, false), - false, - "", - ), - i32_type, - "time.hi", - ); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); - - ctx.builder - .build_store(now_hiptr, time_hi) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - ctx.builder - .build_store(now_loptr, time_lo) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - } else { - unreachable!(); - } - } else { - unreachable!(); + let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { + unreachable!() }; + + let now_loptr = unsafe { + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") + }; + + let ( + BasicValueEnum::IntValue(now_hi), + BasicValueEnum::IntValue(now_lo), + BasicValueEnum::IntValue(dt), + ) = ( + ctx.builder.build_load(now_hiptr, "now.hi"), + ctx.builder.build_load(now_loptr, "now.lo"), + dt, + ) else { + unreachable!() + }; + + let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); + let shifted_hi = ctx.builder.build_left_shift( + zext_hi, + i64_type.const_int(32, false), + "", + ); + let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, ""); + let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now"); + + let time = ctx.builder.build_int_add(now_val, dt, "time"); + let time_hi = ctx.builder.build_int_truncate( + ctx.builder.build_right_shift( + time, + i64_type.const_int(32, false), + false, + "", + ), + i32_type, + "time.hi", + ); + let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); + + ctx.builder + .build_store(now_hiptr, time_hi) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); + ctx.builder + .build_store(now_loptr, time_lo) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); } } @@ -176,14 +176,14 @@ impl TimeFns for NowPinningTimeFns { .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); - if let BasicValueEnum::IntValue(now_raw) = now_raw { - let i64_32 = i64_type.const_int(32, false); - let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); - let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); - ctx.builder.build_or(now_lo, now_hi, "now_mu").into() - } else { - unreachable!(); - } + let BasicValueEnum::IntValue(now_raw) = now_raw else { + unreachable!() + }; + + let i64_32 = i64_type.const_int(32, false); + let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); + let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); + ctx.builder.build_or(now_lo, now_hi, "now_mu").into() } fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { @@ -191,41 +191,41 @@ impl TimeFns for NowPinningTimeFns { let i64_type = ctx.ctx.i64_type(); let i64_32 = i64_type.const_int(32, false); - if let BasicValueEnum::IntValue(time) = t { - let time_hi = ctx.builder.build_int_truncate( - ctx.builder.build_right_shift(time, i64_32, false, ""), - i32_type, - "time.hi", - ); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); - let now = ctx - .module - .get_global("now") - .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder.build_bitcast( - now, - i32_type.ptr_type(AddressSpace::default()), - "now.hi.addr", - ); + let BasicValueEnum::IntValue(time) = t else { + unreachable!() + }; - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { - ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") - }; - ctx.builder - .build_store(now_hiptr, time_hi) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - ctx.builder - .build_store(now_loptr, time_lo) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - } else { - unreachable!(); - } - } else { - unreachable!(); - } + let time_hi = ctx.builder.build_int_truncate( + ctx.builder.build_right_shift(time, i64_32, false, ""), + i32_type, + "time.hi", + ); + let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); + let now = ctx + .module + .get_global("now") + .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); + let now_hiptr = ctx.builder.build_bitcast( + now, + i32_type.ptr_type(AddressSpace::default()), + "now.hi.addr", + ); + + let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { + unreachable!() + }; + + let now_loptr = unsafe { + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") + }; + ctx.builder + .build_store(now_hiptr, time_hi) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); + ctx.builder + .build_store(now_loptr, time_lo) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); } fn emit_delay_mu<'ctx>( @@ -242,41 +242,41 @@ impl TimeFns for NowPinningTimeFns { .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); let now_raw = ctx.builder.build_load(now.as_pointer_value(), ""); - if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) { - let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); - let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); - let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val"); - let time = ctx.builder.build_int_add(now_val, dt, "time"); - let time_hi = ctx.builder.build_int_truncate( - ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), - i32_type, - "now_trunc", - ); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); - let now_hiptr = ctx.builder.build_bitcast( - now, - i32_type.ptr_type(AddressSpace::default()), - "now.hi.addr", - ); + let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) else { + unreachable!() + }; - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { - ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") - }; - ctx.builder - .build_store(now_hiptr, time_hi) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - ctx.builder - .build_store(now_loptr, time_lo) - .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) - .unwrap(); - } else { - unreachable!(); - } - } else { - unreachable!(); - } + let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); + let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); + let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val"); + let time = ctx.builder.build_int_add(now_val, dt, "time"); + let time_hi = ctx.builder.build_int_truncate( + ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), + i32_type, + "now_trunc", + ); + let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); + let now_hiptr = ctx.builder.build_bitcast( + now, + i32_type.ptr_type(AddressSpace::default()), + "now.hi.addr", + ); + + let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { + unreachable!() + }; + + let now_loptr = unsafe { + ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") + }; + ctx.builder + .build_store(now_hiptr, time_hi) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); + ctx.builder + .build_store(now_loptr, time_lo) + .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) + .unwrap(); } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index fc71c8e..7f5881a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -39,11 +39,10 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { - params.clone() - } else { + let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() - } + }; + params.clone() }) .unwrap_or_default(); vars.extend(fun_vars.iter()); @@ -224,7 +223,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { { self.ctx.i64_type() } else { - unreachable!(); + unreachable!() }; Some(ty.const_int(*val as u64, false).into()) } @@ -599,28 +598,27 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( def: &TopLevelDef, params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result, String> { - match def { - TopLevelDef::Class { methods, .. } => { - // TODO: what about other fields that require alloca? - let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2); - let ty = ctx.get_llvm_type(generator, signature.ret).into_pointer_type(); - let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); - let zelf: BasicValueEnum<'ctx> = ctx.builder.build_alloca(zelf_ty, "alloca").into(); - // call `__init__` if there is one - if let Some(fun_id) = fun_id { - let mut sign = signature.clone(); - sign.ret = ctx.primitives.none; - generator.gen_call( - ctx, - Some((signature.ret, zelf.into())), - (&sign, fun_id), - params, - )?; - } - Ok(zelf) - } - TopLevelDef::Function { .. } => unreachable!(), + let TopLevelDef::Class { methods, .. } = def else { + unreachable!() + }; + + // TODO: what about other fields that require alloca? + let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2); + let ty = ctx.get_llvm_type(generator, signature.ret).into_pointer_type(); + let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); + let zelf: BasicValueEnum<'ctx> = ctx.builder.build_alloca(zelf_ty, "alloca").into(); + // call `__init__` if there is one + if let Some(fun_id) = fun_id { + let mut sign = signature.clone(); + sign.ret = ctx.primitives.none; + generator.gen_call( + ctx, + Some((signature.ret, zelf.into())), + (&sign, fun_id), + params, + )?; } + Ok(zelf) } /// See [`CodeGenerator::gen_func_instance`]. @@ -630,74 +628,71 @@ pub fn gen_func_instance<'ctx>( fun: (&FunSignature, &mut TopLevelDef, String), id: usize, ) -> Result { - if let ( + let ( sign, TopLevelDef::Function { name, instance_to_symbol, instance_to_stmt, var_id, resolver, .. }, key, - ) = fun - { - if let Some(sym) = instance_to_symbol.get(&key) { - return Ok(sym.clone()); - } - let symbol = format!("{}.{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key, symbol.clone()); - let mut filter = var_id.clone(); - if let Some((obj_ty, _)) = &obj { - if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty(*obj_ty) { - filter.extend(params.keys()); - } - } - let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(&filter)); - let instance = instance_to_stmt.get(&key).unwrap(); + ) = fun else { unreachable!() }; - let mut store = ConcreteTypeStore::new(); - let mut cache = HashMap::new(); - - let subst = sign - .vars - .iter() - .map(|(id, ty)| { - ( - *instance.subst.get(id).unwrap(), - store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, *ty, &mut cache), - ) - }) - .collect(); - - let mut signature = - store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); - - if let Some(obj) = &obj { - let zelf = - store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); - if let ConcreteTypeEnum::TFunc { args, .. } = &mut signature { - args.insert( - 0, - ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, - ); - } else { - unreachable!() - } - } - let signature = store.add_cty(signature); - - ctx.registry.add_task(CodeGenTask { - symbol_name: symbol.clone(), - body: instance.body.clone(), - resolver: resolver.as_ref().unwrap().clone(), - calls: instance.calls.clone(), - subst, - signature, - store, - unifier_index: instance.unifier_id, - id, - }); - Ok(symbol) - } else { - unreachable!() + if let Some(sym) = instance_to_symbol.get(&key) { + return Ok(sym.clone()); } + let symbol = format!("{}.{}", name, instance_to_symbol.len()); + instance_to_symbol.insert(key, symbol.clone()); + let mut filter = var_id.clone(); + if let Some((obj_ty, _)) = &obj { + if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty(*obj_ty) { + filter.extend(params.keys()); + } + } + let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(&filter)); + let instance = instance_to_stmt.get(&key).unwrap(); + + let mut store = ConcreteTypeStore::new(); + let mut cache = HashMap::new(); + + let subst = sign + .vars + .iter() + .map(|(id, ty)| { + ( + *instance.subst.get(id).unwrap(), + store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, *ty, &mut cache), + ) + }) + .collect(); + + let mut signature = + store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); + + if let Some(obj) = &obj { + let zelf = + store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); + let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { + unreachable!() + }; + + args.insert( + 0, + ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, + ); + } + let signature = store.add_cty(signature); + + ctx.registry.add_task(CodeGenTask { + symbol_name: symbol.clone(), + body: instance.body.clone(), + resolver: resolver.as_ref().unwrap().clone(), + calls: instance.calls.clone(), + subst, + signature, + store, + unifier_index: instance.unifier_id, + id, + }); + Ok(symbol) } /// See [`CodeGenerator::gen_call`]. @@ -946,172 +941,172 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, expr: &Expr>, ) -> Result>, String> { - if let ExprKind::ListComp { elt, generators } = &expr.node { - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let ExprKind::ListComp { elt, generators } = &expr.node else { + unreachable!() + }; - let init_bb = ctx.ctx.append_basic_block(current, "listcomp.init"); - let test_bb = ctx.ctx.append_basic_block(current, "listcomp.test"); - let body_bb = ctx.ctx.append_basic_block(current, "listcomp.body"); - let cont_bb = ctx.ctx.append_basic_block(current, "listcomp.cont"); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - ctx.builder.build_unconditional_branch(init_bb); + let init_bb = ctx.ctx.append_basic_block(current, "listcomp.init"); + let test_bb = ctx.ctx.append_basic_block(current, "listcomp.test"); + let body_bb = ctx.ctx.append_basic_block(current, "listcomp.body"); + let cont_bb = ctx.ctx.append_basic_block(current, "listcomp.cont"); - ctx.builder.position_at_end(init_bb); + ctx.builder.build_unconditional_branch(init_bb); - let Comprehension { target, iter, ifs, .. } = &generators[0]; - let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { - v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? - } else { - for bb in [test_bb, body_bb, cont_bb] { - ctx.builder.position_at_end(bb); - ctx.builder.build_unreachable(); - } + ctx.builder.position_at_end(init_bb); - return Ok(None) - }; - let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); - let zero_size_t = size_t.const_zero(); - let zero_32 = int32.const_zero(); - - let index = generator.gen_var_alloc(ctx, size_t.into(), Some("index.addr"))?; - ctx.builder.build_store(index, zero_size_t); - - let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); - let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); - let list; - let list_content; - - if is_range { - let iter_val = iter_val.into_pointer_value(); - let (start, stop, step) = destructure_range(ctx, iter_val); - let diff = ctx.builder.build_int_sub(stop, start, "diff"); - // add 1 to the length as the value is rounded to zero - // the length may be 1 more than the actual length if the division is exact, but the - // length is a upper bound only anyway so it does not matter. - let length = ctx.builder.build_int_signed_div(diff, step, "div"); - let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); - // in case length is non-positive - let is_valid = - ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check"); - - let list_alloc_size = ctx.builder.build_select( - is_valid, - ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len"), - zero_size_t, - "listcomp.alloc_size" - ); - list = allocate_list( - generator, - ctx, - elem_ty, - list_alloc_size.into_int_value(), - Some("listcomp.addr") - ); - list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr")) - .into_pointer_value(); - - let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); - ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); - - ctx.builder.build_conditional_branch( - gen_in_range_check(ctx, start, stop, step), - test_bb, - cont_bb, - ); - - ctx.builder.position_at_end(test_bb); - // add and test - let tmp = ctx.builder.build_int_add( - ctx.builder.build_load(i, "i").into_int_value(), - step, - "start_loop", - ); - ctx.builder.build_store(i, tmp); - ctx.builder.build_conditional_branch( - gen_in_range_check(ctx, tmp, stop, step), - body_bb, - cont_bb, - ); - - ctx.builder.position_at_end(body_bb); - } else { - let length = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero_size_t, int32.const_int(1, false)], - Some("length"), - ) - .into_int_value(); - list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp")); - list_content = - ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value(); - let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; - // counter = -1 - ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); - ctx.builder.build_unconditional_branch(test_bb); - - ctx.builder.position_at_end(test_bb); - let tmp = ctx.builder.build_load(counter, "i").into_int_value(); - let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc"); - ctx.builder.build_store(counter, tmp); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp"); - ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb); - - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load(iter_val.into_pointer_value(), &[zero_size_t, zero_32], Some("arr.addr")) - .into_pointer_value(); - let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); - generator.gen_assign(ctx, target, val.into())?; + let Comprehension { target, iter, ifs, .. } = &generators[0]; + let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { + v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? + } else { + for bb in [test_bb, body_bb, cont_bb] { + ctx.builder.position_at_end(bb); + ctx.builder.build_unreachable(); } - // Emits the content of `cont_bb` - let emit_cont_bb = |ctx: &CodeGenContext| { - ctx.builder.position_at_end(cont_bb); - let len_ptr = unsafe { - ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length") - }; - ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); + return Ok(None) + }; + let int32 = ctx.ctx.i32_type(); + let size_t = generator.get_size_type(ctx.ctx); + let zero_size_t = size_t.const_zero(); + let zero_32 = int32.const_zero(); + + let index = generator.gen_var_alloc(ctx, size_t.into(), Some("index.addr"))?; + ctx.builder.build_store(index, zero_size_t); + + let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); + let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); + let list; + let list_content; + + if is_range { + let iter_val = iter_val.into_pointer_value(); + let (start, stop, step) = destructure_range(ctx, iter_val); + let diff = ctx.builder.build_int_sub(stop, start, "diff"); + // add 1 to the length as the value is rounded to zero + // the length may be 1 more than the actual length if the division is exact, but the + // length is a upper bound only anyway so it does not matter. + let length = ctx.builder.build_int_signed_div(diff, step, "div"); + let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); + // in case length is non-positive + let is_valid = + ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check"); + + let list_alloc_size = ctx.builder.build_select( + is_valid, + ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len"), + zero_size_t, + "listcomp.alloc_size" + ); + list = allocate_list( + generator, + ctx, + elem_ty, + list_alloc_size.into_int_value(), + Some("listcomp.addr") + ); + list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr")) + .into_pointer_value(); + + let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); + ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); + + ctx.builder.build_conditional_branch( + gen_in_range_check(ctx, start, stop, step), + test_bb, + cont_bb, + ); + + ctx.builder.position_at_end(test_bb); + // add and test + let tmp = ctx.builder.build_int_add( + ctx.builder.build_load(i, "i").into_int_value(), + step, + "start_loop", + ); + ctx.builder.build_store(i, tmp); + ctx.builder.build_conditional_branch( + gen_in_range_check(ctx, tmp, stop, step), + body_bb, + cont_bb, + ); + + ctx.builder.position_at_end(body_bb); + } else { + let length = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, int32.const_int(1, false)], + Some("length"), + ) + .into_int_value(); + list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp")); + list_content = + ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value(); + let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; + // counter = -1 + ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); + ctx.builder.build_unconditional_branch(test_bb); + + ctx.builder.position_at_end(test_bb); + let tmp = ctx.builder.build_load(counter, "i").into_int_value(); + let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc"); + ctx.builder.build_store(counter, tmp); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp"); + ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb); + + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load(iter_val.into_pointer_value(), &[zero_size_t, zero_32], Some("arr.addr")) + .into_pointer_value(); + let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); + generator.gen_assign(ctx, target, val.into())?; + } + + // Emits the content of `cont_bb` + let emit_cont_bb = |ctx: &CodeGenContext| { + ctx.builder.position_at_end(cont_bb); + let len_ptr = unsafe { + ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length") }; + ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); + }; - for cond in ifs { - let result = if let Some(v) = generator.gen_expr(ctx, cond)? { - v.to_basic_value_enum(ctx, generator, cond.custom.unwrap())?.into_int_value() - } else { - // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the - // no element matches the predicate - emit_cont_bb(ctx); - - return Ok(None) - }; - let result = generator.bool_to_i1(ctx, result); - let succ = ctx.ctx.append_basic_block(current, "then"); - ctx.builder.build_conditional_branch(result, succ, test_bb); - - ctx.builder.position_at_end(succ); - } - - let Some(elem) = generator.gen_expr(ctx, elt)? else { - // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents + for cond in ifs { + let result = if let Some(v) = generator.gen_expr(ctx, cond)? { + v.to_basic_value_enum(ctx, generator, cond.custom.unwrap())?.into_int_value() + } else { + // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the + // no element matches the predicate emit_cont_bb(ctx); return Ok(None) }; - let i = ctx.builder.build_load(index, "i").into_int_value(); - let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }; - let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?; - ctx.builder.build_store(elem_ptr, val); - ctx.builder - .build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc")); - ctx.builder.build_unconditional_branch(test_bb); + let result = generator.bool_to_i1(ctx, result); + let succ = ctx.ctx.append_basic_block(current, "then"); + ctx.builder.build_conditional_branch(result, succ, test_bb); + ctx.builder.position_at_end(succ); + } + + let Some(elem) = generator.gen_expr(ctx, elt)? else { + // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents emit_cont_bb(ctx); - Ok(Some(list.into())) - } else { - unreachable!() - } + return Ok(None) + }; + let i = ctx.builder.build_load(index, "i").into_int_value(); + let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }; + let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?; + ctx.builder.build_store(elem_ptr, val); + ctx.builder + .build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc")); + ctx.builder.build_unconditional_branch(test_bb); + + emit_cont_bb(ctx); + + Ok(Some(list.into())) } /// Generates LLVM IR for a [binary operator expression][expr]. @@ -1170,9 +1165,11 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( .unwrap_left(); Ok(Some(res.into())) } else { - let (op_name, id) = if let TypeEnum::TObj { fields, obj_id, .. } = - ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() - { + let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap()); + let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { + unreachable!("must be tobj") + }; + let (op_name, id) = { let (binop_name, binop_assign_name) = ( binop_name(op).into(), binop_assign_name(op).into() @@ -1183,34 +1180,33 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( } else { (binop_name, *obj_id) } - } else { - unreachable!("must be tobj") }; + let signature = match ctx.calls.get(&loc.into()) { Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), None => { - if let TypeEnum::TObj { fields, .. } = - ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() - { - let fn_ty = fields.get(&op_name).unwrap().0; - if let TypeEnum::TFunc(sig) = ctx.unifier.get_ty_immutable(fn_ty).as_ref() { - sig.clone() - } else { - unreachable!("must be func sig") - } - } else { + let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap()); + let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else { unreachable!("must be tobj") - } + }; + + let fn_ty = fields.get(&op_name).unwrap().0; + let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty); + let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { + unreachable!() + }; + + sig.clone() }, }; let fun_id = { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - if let TopLevelDef::Class { methods, .. } = &*obj_def { - methods.iter().find(|method| method.0 == op_name).unwrap().2 - } else { + let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() - } + }; + + methods.iter().find(|method| method.0 == op_name).unwrap().2 }; generator .gen_call( @@ -1290,11 +1286,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } let ty = if elements.is_empty() { - if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) { - ctx.get_llvm_type(generator, *ty) - } else { + let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) else { unreachable!() - } + }; + + ctx.get_llvm_type(generator, *ty) } else { elements[0].get_type() }; @@ -1636,11 +1632,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.unifier.get_call_signature(*call).unwrap() } else { let ty = func.custom.unwrap(); - if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) { - sign.clone() - } else { + let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { unreachable!() - } + }; + + sign.clone() }; let func = func.as_ref(); match &func.node { @@ -1669,11 +1665,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let fun_id = { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - if let TopLevelDef::Class { methods, .. } = &*obj_def { - methods.iter().find(|method| method.0 == *attr).unwrap().2 - } else { + let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() - } + }; + + methods.iter().find(|method| method.0 == *attr).unwrap().2 }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant @@ -1755,125 +1751,127 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::Subscript { value, slice, .. } => { - if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(value.custom.unwrap()) { - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() - } else { - return Ok(None) - }; - let ty = ctx.get_llvm_type(generator, *ty); - let arr_ptr = ctx.build_gep_and_load(v, &[zero, zero], Some("arr.addr")) - .into_pointer_value(); - if let ExprKind::Slice { lower, upper, step } = &slice.node { - let one = int32.const_int(1, false); - let Some((start, end, step)) = - handle_slice_indices(lower, upper, step, ctx, generator, v)? else { - return Ok(None) - }; - let length = calculate_len_for_slice_range( - generator, - ctx, - start, - ctx.builder - .build_select( - ctx.builder.build_int_compare( - IntPredicate::SLT, - step, - zero, - "is_neg", - ), - ctx.builder.build_int_sub(end, one, "e_min_one"), - ctx.builder.build_int_add(end, one, "e_add_one"), - "final_e", - ) - .into_int_value(), - step, - ); - let res_array_ret = allocate_list(generator, ctx, ty, length, Some("ret")); - let Some(res_ind) = - handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)? else { - return Ok(None) - }; - list_slice_assignment( - generator, - ctx, - ty, - res_array_ret, - res_ind, - v, - (start, end, step), - ); - res_array_ret.into() - } else { - let len = ctx - .build_gep_and_load(v, &[zero, int32.const_int(1, false)], Some("len")) - .into_int_value(); - let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() + match &*ctx.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { ty } => { + let v = if let Some(v) = generator.gen_expr(ctx, value)? { + v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() } else { return Ok(None) }; - let raw_index = ctx.builder.build_int_s_extend( - raw_index, - generator.get_size_type(ctx.ctx), - "sext", - ); - // handle negative index - let is_negative = ctx.builder.build_int_compare( - IntPredicate::SLT, - raw_index, - generator.get_size_type(ctx.ctx).const_zero(), - "is_neg", - ); - let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); - let index = ctx - .builder - .build_select(is_negative, adjusted, raw_index, "index") - .into_int_value(); - // unsigned less than is enough, because negative index after adjustment is - // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder.build_int_compare( - IntPredicate::ULT, - index, - len, - "inbound", - ); - ctx.make_assert( - generator, - bound_check, - "0:IndexError", - "index {0} out of bounds 0:{1}", - [Some(raw_index), Some(len), None], - expr.location, - ); - ctx.build_gep_and_load(arr_ptr, &[index], None).into() - } - } else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { - let index: u32 = - if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node { - (*v).try_into().unwrap() + let ty = ctx.get_llvm_type(generator, *ty); + let arr_ptr = ctx.build_gep_and_load(v, &[zero, zero], Some("arr.addr")) + .into_pointer_value(); + if let ExprKind::Slice { lower, upper, step } = &slice.node { + let one = int32.const_int(1, false); + let Some((start, end, step)) = + handle_slice_indices(lower, upper, step, ctx, generator, v)? else { + return Ok(None) + }; + let length = calculate_len_for_slice_range( + generator, + ctx, + start, + ctx.builder + .build_select( + ctx.builder.build_int_compare( + IntPredicate::SLT, + step, + zero, + "is_neg", + ), + ctx.builder.build_int_sub(end, one, "e_min_one"), + ctx.builder.build_int_add(end, one, "e_add_one"), + "final_e", + ) + .into_int_value(), + step, + ); + let res_array_ret = allocate_list(generator, ctx, ty, length, Some("ret")); + let Some(res_ind) = + handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)? else { + return Ok(None) + }; + list_slice_assignment( + generator, + ctx, + ty, + res_array_ret, + res_ind, + v, + (start, end, step), + ); + res_array_ret.into() } else { - unreachable!("tuple subscript must be const int after type check"); - }; - match generator.gen_expr(ctx, value)? { - Some(ValueEnum::Dynamic(v)) => { - let v = v.into_struct_value(); - ctx.builder.build_extract_value(v, index, "tup_elem").unwrap().into() - } - Some(ValueEnum::Static(v)) => { - if let Some(v) = v.get_tuple_element(index) { - v + let len = ctx + .build_gep_and_load(v, &[zero, int32.const_int(1, false)], Some("len")) + .into_int_value(); + let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { + v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() } else { - let tup = v - .to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_struct_value(); - ctx.builder.build_extract_value(tup, index, "tup_elem").unwrap().into() - } + return Ok(None) + }; + let raw_index = ctx.builder.build_int_s_extend( + raw_index, + generator.get_size_type(ctx.ctx), + "sext", + ); + // handle negative index + let is_negative = ctx.builder.build_int_compare( + IntPredicate::SLT, + raw_index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ); + let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); + let index = ctx + .builder + .build_select(is_negative, adjusted, raw_index, "index") + .into_int_value(); + // unsigned less than is enough, because negative index after adjustment is + // bigger than the length (for unsigned cmp) + let bound_check = ctx.builder.build_int_compare( + IntPredicate::ULT, + index, + len, + "inbound", + ); + ctx.make_assert( + generator, + bound_check, + "0:IndexError", + "index {0} out of bounds 0:{1}", + [Some(raw_index), Some(len), None], + expr.location, + ); + ctx.build_gep_and_load(arr_ptr, &[index], None).into() } - None => return Ok(None), } - } else { - unreachable!("should not be other subscriptable types after type check"); + TypeEnum::TTuple { .. } => { + let index: u32 = + if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node { + (*v).try_into().unwrap() + } else { + unreachable!("tuple subscript must be const int after type check"); + }; + match generator.gen_expr(ctx, value)? { + Some(ValueEnum::Dynamic(v)) => { + let v = v.into_struct_value(); + ctx.builder.build_extract_value(v, index, "tup_elem").unwrap().into() + } + Some(ValueEnum::Static(v)) => { + if let Some(v) = v.get_tuple_element(index) { + v + } else { + let tup = v + .to_basic_value_enum(ctx, generator, value.custom.unwrap())? + .into_struct_value(); + ctx.builder.build_extract_value(tup, index, "tup_elem").unwrap().into() + } + } + None => return Ok(None), + } + } + _ => unreachable!("should not be other subscriptable types after type check"), } }, ExprKind::ListComp { .. } => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 06c0fc8..0917d8c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -451,40 +451,38 @@ fn get_llvm_type<'ctx>( // a struct with fields in the order of declaration let top_level_defs = top_level.definitions.read(); let definition = top_level_defs.get(obj_id.0).unwrap(); - let ty = if let TopLevelDef::Class { fields: fields_list, .. } = - &*definition.read() - { - let name = unifier.stringify(ty); - if let Some(t) = module.get_struct_type(&name) { - t.ptr_type(AddressSpace::default()).into() - } else { - let struct_type = ctx.opaque_struct_type(&name); - type_cache.insert( - unifier.get_representative(ty), - struct_type.ptr_type(AddressSpace::default()).into() - ); - let fields = fields_list - .iter() - .map(|f| { - get_llvm_type( - ctx, - module, - generator, - unifier, - top_level, - type_cache, - primitives, - fields[&f.0].0, - ) - }) - .collect_vec(); - struct_type.set_body(&fields, false); - struct_type.ptr_type(AddressSpace::default()).into() - } - } else { + let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read() else { unreachable!() }; - return ty; + + let name = unifier.stringify(ty); + let ty = if let Some(t) = module.get_struct_type(&name) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into() + ); + let fields = fields_list + .iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + primitives, + fields[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty } TTuple { ty } => { // a struct with fields in the order present in the tuple @@ -661,22 +659,21 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte // NOTE: special handling of option cannot use this type cache since it contains type var, // handled inside get_llvm_type instead - let (args, ret) = if let ConcreteTypeEnum::TFunc { args, ret, .. } = - task.store.get(task.signature) - { - ( - args.iter() - .map(|arg| FuncArg { - name: arg.name, - ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), - default_value: arg.default_value.clone(), - }) - .collect_vec(), - task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), - ) - } else { + let ConcreteTypeEnum::TFunc { args, ret, .. } = + task.store.get(task.signature) else { unreachable!() }; + + let (args, ret) = ( + args.iter() + .map(|arg| FuncArg { + name: arg.name, + ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), + default_value: arg.default_value.clone(), + }) + .collect_vec(), + task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), + ); let ret_type = if unifier.unioned(ret, primitives.none) { None } else { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 8f76f41..0e0449e 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -171,49 +171,47 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ) -> Result<(), String> { match &target.node { ExprKind::Tuple { elts, .. } => { - if let BasicValueEnum::StructValue(v) = - value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? - { - for (i, elt) in elts.iter().enumerate() { - let v = ctx - .builder - .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") - .unwrap(); - generator.gen_assign(ctx, elt, v.into())?; - } - } else { + let BasicValueEnum::StructValue(v) = + value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? else { unreachable!() + }; + + for (i, elt) in elts.iter().enumerate() { + let v = ctx + .builder + .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") + .unwrap(); + generator.gen_assign(ctx, elt, v.into())?; } } ExprKind::Subscript { value: ls, slice, .. } if matches!(&slice.node, ExprKind::Slice { .. }) => { - if let ExprKind::Slice { lower, upper, step } = &slice.node { - let ls = generator - .gen_expr(ctx, ls)? - .unwrap() - .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? - .into_pointer_value(); - let Some((start, end, step)) = - handle_slice_indices(lower, upper, step, ctx, generator, ls)? else { - return Ok(()) - }; - let value = value - .to_basic_value_enum(ctx, generator, target.custom.unwrap())? - .into_pointer_value(); - let ty = - if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) { - ctx.get_llvm_type(generator, *ty) - } else { - unreachable!() - }; - let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { - return Ok(()) - }; - list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); - } else { + let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() - } + }; + + let ls = generator + .gen_expr(ctx, ls)? + .unwrap() + .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? + .into_pointer_value(); + let Some((start, end, step)) = + handle_slice_indices(lower, upper, step, ctx, generator, ls)? else { + return Ok(()) + }; + let value = value + .to_basic_value_enum(ctx, generator, target.custom.unwrap())? + .into_pointer_value(); + let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else { + unreachable!() + }; + + let ty = ctx.get_llvm_type(generator, *ty); + let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { + return Ok(()) + }; + list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); } _ => { let name = if let ExprKind::Name { id, .. } = &target.node { @@ -245,158 +243,159 @@ pub fn gen_for( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - if let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node { - // var_assignment static values may be changed in another branch - // if so, remove the static value as it may not be correct in this branch - let var_assignment = ctx.var_assignment.clone(); - - let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); - let zero = int32.const_zero(); - let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); - let body_bb = ctx.ctx.append_basic_block(current, "for.body"); - let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); - // if there is no orelse, we just go to cont_bb - let orelse_bb = if orelse.is_empty() { - cont_bb - } else { - ctx.ctx.append_basic_block(current, "for.orelse") - }; - - // Whether the iterable is a range() expression - let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); - - // The BB containing the increment expression - let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); - // The BB containing the loop condition check - let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); - - // store loop bb information and restore it later - let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); - - let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { - v.to_basic_value_enum( - ctx, - generator, - iter.custom.unwrap(), - )? - } else { - return Ok(()) - }; - if is_iterable_range_expr { - let iter_val = iter_val.into_pointer_value(); - // Internal variable for loop; Cannot be assigned - let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; - // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed - let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? else { - unreachable!() - }; - let (start, stop, step) = destructure_range(ctx, iter_val); - - ctx.builder.build_store(i, start); - - // Check "If step is zero, ValueError is raised." - let rangenez = ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), ""); - ctx.make_assert( - generator, - rangenez, - "ValueError", - "range() arg 3 must not be zero", - [None, None, None], - ctx.current_loc - ); - ctx.builder.build_unconditional_branch(cond_bb); - - { - ctx.builder.position_at_end(cond_bb); - ctx.builder.build_conditional_branch( - gen_in_range_check( - ctx, - ctx.builder.build_load(i, "").into_int_value(), - stop, - step - ), - body_bb, - orelse_bb, - ); - } - - ctx.builder.position_at_end(incr_bb); - let next_i = ctx.builder.build_int_add( - ctx.builder.build_load(i, "").into_int_value(), - step, - "inc", - ); - ctx.builder.build_store(i, next_i); - ctx.builder.build_unconditional_branch(cond_bb); - - ctx.builder.position_at_end(body_bb); - ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value()); - generator.gen_block(ctx, body.iter())?; - } else { - let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; - ctx.builder.build_store(index_addr, size_t.const_zero()); - let len = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero, int32.const_int(1, false)], - Some("len") - ) - .into_int_value(); - ctx.builder.build_unconditional_branch(cond_bb); - - ctx.builder.position_at_end(cond_bb); - let index = ctx.builder.build_load(index_addr, "for.index").into_int_value(); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond"); - ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb); - - ctx.builder.position_at_end(incr_bb); - let index = ctx.builder.build_load(index_addr, "").into_int_value(); - let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc"); - ctx.builder.build_store(index_addr, inc); - ctx.builder.build_unconditional_branch(cond_bb); - - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) - .into_pointer_value(); - let index = ctx.builder.build_load(index_addr, "for.index").into_int_value(); - let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); - generator.gen_assign(ctx, target, val.into())?; - generator.gen_block(ctx, body.iter())?; - } - - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } - } - - if !ctx.is_terminated() { - ctx.builder.build_unconditional_branch(incr_bb); - } - - if !orelse.is_empty() { - ctx.builder.position_at_end(orelse_bb); - generator.gen_block(ctx, orelse.iter())?; - if !ctx.is_terminated() { - ctx.builder.build_unconditional_branch(cont_bb); - } - } - - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } - } - - ctx.builder.position_at_end(cont_bb); - ctx.loop_target = loop_bb; - } else { + let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() + }; + + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + + let int32 = ctx.ctx.i32_type(); + let size_t = generator.get_size_type(ctx.ctx); + let zero = int32.const_zero(); + let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); + let body_bb = ctx.ctx.append_basic_block(current, "for.body"); + let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = if orelse.is_empty() { + cont_bb + } else { + ctx.ctx.append_basic_block(current, "for.orelse") + }; + + // Whether the iterable is a range() expression + let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); + + // The BB containing the increment expression + let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); + // The BB containing the loop condition check + let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); + + // store loop bb information and restore it later + let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); + + let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { + v.to_basic_value_enum( + ctx, + generator, + iter.custom.unwrap(), + )? + } else { + return Ok(()) + }; + if is_iterable_range_expr { + let iter_val = iter_val.into_pointer_value(); + // Internal variable for loop; Cannot be assigned + let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; + // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed + let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? else { + unreachable!() + }; + let (start, stop, step) = destructure_range(ctx, iter_val); + + ctx.builder.build_store(i, start); + + // Check "If step is zero, ValueError is raised." + let rangenez = ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), ""); + ctx.make_assert( + generator, + rangenez, + "ValueError", + "range() arg 3 must not be zero", + [None, None, None], + ctx.current_loc + ); + ctx.builder.build_unconditional_branch(cond_bb); + + { + ctx.builder.position_at_end(cond_bb); + ctx.builder.build_conditional_branch( + gen_in_range_check( + ctx, + ctx.builder.build_load(i, "").into_int_value(), + stop, + step + ), + body_bb, + orelse_bb, + ); + } + + ctx.builder.position_at_end(incr_bb); + let next_i = ctx.builder.build_int_add( + ctx.builder.build_load(i, "").into_int_value(), + step, + "inc", + ); + ctx.builder.build_store(i, next_i); + ctx.builder.build_unconditional_branch(cond_bb); + + ctx.builder.position_at_end(body_bb); + ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value()); + generator.gen_block(ctx, body.iter())?; + } else { + let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; + ctx.builder.build_store(index_addr, size_t.const_zero()); + let len = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero, int32.const_int(1, false)], + Some("len") + ) + .into_int_value(); + ctx.builder.build_unconditional_branch(cond_bb); + + ctx.builder.position_at_end(cond_bb); + let index = ctx.builder.build_load(index_addr, "for.index").into_int_value(); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond"); + ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb); + + ctx.builder.position_at_end(incr_bb); + let index = ctx.builder.build_load(index_addr, "").into_int_value(); + let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc"); + ctx.builder.build_store(index_addr, inc); + ctx.builder.build_unconditional_branch(cond_bb); + + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) + .into_pointer_value(); + let index = ctx.builder.build_load(index_addr, "for.index").into_int_value(); + let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); + generator.gen_assign(ctx, target, val.into())?; + generator.gen_block(ctx, body.iter())?; } + + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } + + if !ctx.is_terminated() { + ctx.builder.build_unconditional_branch(incr_bb); + } + + if !orelse.is_empty() { + ctx.builder.position_at_end(orelse_bb); + generator.gen_block(ctx, orelse.iter())?; + if !ctx.is_terminated() { + ctx.builder.build_unconditional_branch(cont_bb); + } + } + + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } + + ctx.builder.position_at_end(cont_bb); + ctx.loop_target = loop_bb; + Ok(()) } @@ -406,66 +405,68 @@ pub fn gen_while( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - if let StmtKind::While { test, body, orelse, .. } = &stmt.node { - // var_assignment static values may be changed in another branch - // if so, remove the static value as it may not be correct in this branch - let var_assignment = ctx.var_assignment.clone(); - - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let test_bb = ctx.ctx.append_basic_block(current, "while.test"); - let body_bb = ctx.ctx.append_basic_block(current, "while.body"); - let cont_bb = ctx.ctx.append_basic_block(current, "while.cont"); - // if there is no orelse, we just go to cont_bb - let orelse_bb = - if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "while.orelse") }; - // store loop bb information and restore it later - let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); - ctx.builder.build_unconditional_branch(test_bb); - ctx.builder.position_at_end(test_bb); - let test = if let Some(v) = generator.gen_expr(ctx, test)? { - v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? - } else { - for bb in [body_bb, cont_bb] { - ctx.builder.position_at_end(bb); - ctx.builder.build_unreachable(); - } - - return Ok(()) - }; - if let BasicValueEnum::IntValue(test) = test { - ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb); - } else { - unreachable!() - }; - ctx.builder.position_at_end(body_bb); - generator.gen_block(ctx, body.iter())?; - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } - } - if !ctx.is_terminated() { - ctx.builder.build_unconditional_branch(test_bb); - } - if !orelse.is_empty() { - ctx.builder.position_at_end(orelse_bb); - generator.gen_block(ctx, orelse.iter())?; - if !ctx.is_terminated() { - ctx.builder.build_unconditional_branch(cont_bb); - } - } - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } - } - ctx.builder.position_at_end(cont_bb); - ctx.loop_target = loop_bb; - } else { + let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() + }; + + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "while.test"); + let body_bb = ctx.ctx.append_basic_block(current, "while.body"); + let cont_bb = ctx.ctx.append_basic_block(current, "while.cont"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = + if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "while.orelse") }; + // store loop bb information and restore it later + let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let test = if let Some(v) = generator.gen_expr(ctx, test)? { + v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? + } else { + for bb in [body_bb, cont_bb] { + ctx.builder.position_at_end(bb); + ctx.builder.build_unreachable(); + } + + return Ok(()) + }; + let BasicValueEnum::IntValue(test) = test else { + unreachable!() + }; + + ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb); + + ctx.builder.position_at_end(body_bb); + generator.gen_block(ctx, body.iter())?; + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } } + if !ctx.is_terminated() { + ctx.builder.build_unconditional_branch(test_bb); + } + if !orelse.is_empty() { + ctx.builder.position_at_end(orelse_bb); + generator.gen_block(ctx, orelse.iter())?; + if !ctx.is_terminated() { + ctx.builder.build_unconditional_branch(cont_bb); + } + } + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } + ctx.builder.position_at_end(cont_bb); + ctx.loop_target = loop_bb; + Ok(()) } @@ -475,67 +476,68 @@ pub fn gen_if( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - if let StmtKind::If { test, body, orelse, .. } = &stmt.node { - // var_assignment static values may be changed in another branch - // if so, remove the static value as it may not be correct in this branch - let var_assignment = ctx.var_assignment.clone(); + let StmtKind::If { test, body, orelse, .. } = &stmt.node else { + unreachable!() + }; - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let test_bb = ctx.ctx.append_basic_block(current, "if.test"); - let body_bb = ctx.ctx.append_basic_block(current, "if.body"); - let mut cont_bb = None; - // if there is no orelse, we just go to cont_bb - let orelse_bb = if orelse.is_empty() { - cont_bb = Some(ctx.ctx.append_basic_block(current, "if.cont")); - cont_bb.unwrap() - } else { - ctx.ctx.append_basic_block(current, "if.orelse") - }; - ctx.builder.build_unconditional_branch(test_bb); - ctx.builder.position_at_end(test_bb); - let test = generator - .gen_expr(ctx, test) - .and_then(|v| v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose())?; - if let Some(BasicValueEnum::IntValue(test)) = test { - ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb); - }; - ctx.builder.position_at_end(body_bb); - generator.gen_block(ctx, body.iter())?; - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "if.test"); + let body_bb = ctx.ctx.append_basic_block(current, "if.body"); + let mut cont_bb = None; + // if there is no orelse, we just go to cont_bb + let orelse_bb = if orelse.is_empty() { + cont_bb = Some(ctx.ctx.append_basic_block(current, "if.cont")); + cont_bb.unwrap() + } else { + ctx.ctx.append_basic_block(current, "if.orelse") + }; + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let test = generator + .gen_expr(ctx, test) + .and_then(|v| v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose())?; + if let Some(BasicValueEnum::IntValue(test)) = test { + ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb); + }; + ctx.builder.position_at_end(body_bb); + generator.gen_block(ctx, body.iter())?; + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; } + } + if !ctx.is_terminated() { + if cont_bb.is_none() { + cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); + } + ctx.builder.build_unconditional_branch(cont_bb.unwrap()); + } + if !orelse.is_empty() { + ctx.builder.position_at_end(orelse_bb); + generator.gen_block(ctx, orelse.iter())?; if !ctx.is_terminated() { if cont_bb.is_none() { cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); } ctx.builder.build_unconditional_branch(cont_bb.unwrap()); } - if !orelse.is_empty() { - ctx.builder.position_at_end(orelse_bb); - generator.gen_block(ctx, orelse.iter())?; - if !ctx.is_terminated() { - if cont_bb.is_none() { - cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); - } - ctx.builder.build_unconditional_branch(cont_bb.unwrap()); - } - } - if let Some(cont_bb) = cont_bb { - ctx.builder.position_at_end(cont_bb); - } - for (k, (_, _, counter)) in &var_assignment { - let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); - if counter != counter2 { - *static_val = None; - } - } - } else { - unreachable!() } + if let Some(cont_bb) = cont_bb { + ctx.builder.position_at_end(cont_bb); + } + for (k, (_, _, counter)) in &var_assignment { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } + Ok(()) } @@ -595,16 +597,16 @@ pub fn exn_constructor<'ctx>( let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let zelf_id = { - if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) { - obj_id.0 - } else { + let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) else { unreachable!() - } + }; + obj_id.0 }; let defs = ctx.top_level.definitions.read(); let def = defs[zelf_id].read(); - let zelf_name = - if let TopLevelDef::Class { name, .. } = &*def { *name } else { unreachable!() }; + let TopLevelDef::Class { name: zelf_name, .. } = &*def else { + unreachable!() + }; let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); unsafe { let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); @@ -715,320 +717,321 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, 'a>, target: &Stmt>, ) -> Result<(), String> { - if let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node { - // if we need to generate anything related to exception, we must have personality defined - let personality_symbol = ctx.top_level.personality_symbol.as_ref().unwrap(); - let personality = ctx.module.get_function(personality_symbol).unwrap_or_else(|| { - let ty = ctx.ctx.i32_type().fn_type(&[], true); - ctx.module.add_function(personality_symbol, ty, None) - }); - let exception_type = ctx.get_llvm_type(generator, ctx.primitives.exception); - let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::default()); - let current_block = ctx.builder.get_insert_block().unwrap(); - let current_fun = current_block.get_parent().unwrap(); - let landingpad = ctx.ctx.append_basic_block(current_fun, "try.landingpad"); - let dispatcher = ctx.ctx.append_basic_block(current_fun, "try.dispatch"); - let mut dispatcher_end = dispatcher; - ctx.builder.position_at_end(dispatcher); - let exn = ctx.builder.build_phi(exception_type, "exn"); - ctx.builder.position_at_end(current_block); + let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else { + unreachable!() + }; - let mut cleanup = None; - let mut old_loop_target = None; - let mut old_return = None; - let mut final_data = None; - let has_cleanup = !finalbody.is_empty(); - if has_cleanup { - let final_state = generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?; - final_data = Some((final_state, Vec::new(), Vec::new())); - if let Some((continue_target, break_target)) = ctx.loop_target { - let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); - let continue_proxy = ctx.ctx.append_basic_block(current_fun, "try.continue"); - final_proxy(ctx, break_target, break_proxy, final_data.as_mut().unwrap()); - final_proxy(ctx, continue_target, continue_proxy, final_data.as_mut().unwrap()); - old_loop_target = ctx.loop_target.replace((continue_proxy, break_proxy)); - } - let return_proxy = ctx.ctx.append_basic_block(current_fun, "try.return"); - if let Some(return_target) = ctx.return_target { - final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); - } else { - let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target"); - ctx.builder.position_at_end(return_target); - let return_value = ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret")); - ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)); - ctx.builder.position_at_end(current_block); - final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); - } - old_return = ctx.return_target.replace(return_proxy); - cleanup = Some(ctx.ctx.append_basic_block(current_fun, "try.cleanup")); - } + // if we need to generate anything related to exception, we must have personality defined + let personality_symbol = ctx.top_level.personality_symbol.as_ref().unwrap(); + let personality = ctx.module.get_function(personality_symbol).unwrap_or_else(|| { + let ty = ctx.ctx.i32_type().fn_type(&[], true); + ctx.module.add_function(personality_symbol, ty, None) + }); + let exception_type = ctx.get_llvm_type(generator, ctx.primitives.exception); + let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::default()); + let current_block = ctx.builder.get_insert_block().unwrap(); + let current_fun = current_block.get_parent().unwrap(); + let landingpad = ctx.ctx.append_basic_block(current_fun, "try.landingpad"); + let dispatcher = ctx.ctx.append_basic_block(current_fun, "try.dispatch"); + let mut dispatcher_end = dispatcher; + ctx.builder.position_at_end(dispatcher); + let exn = ctx.builder.build_phi(exception_type, "exn"); + ctx.builder.position_at_end(current_block); - let mut clauses = Vec::new(); - let mut found_catch_all = false; - for handler_node in handlers { - let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; - // none or Exception - if type_.is_none() - || ctx - .unifier - .unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) - { - clauses.push(None); - found_catch_all = true; - break; - } - - let type_ = type_.as_ref().unwrap(); - let exn_name = ctx.resolver.get_type_name( - &ctx.top_level.definitions.read(), - &mut ctx.unifier, - type_.custom.unwrap(), - ); - let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { - *obj_id - } else { - unreachable!() - }; - let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); - let exn_id = ctx.resolver.get_string_id(&exception_name); - let exn_id_global = - ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{exn_id}")); - exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false)); - clauses.push(Some(exn_id_global.as_pointer_value().as_basic_value_enum())); - } - let mut all_clauses = clauses.clone(); - if let Some(old_clauses) = &ctx.outer_catch_clauses { - if !found_catch_all { - all_clauses.extend_from_slice(&old_clauses.0); - } - } - let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn)); - let old_unwind = ctx.unwind_target.replace(landingpad); - generator.gen_block(ctx, body.iter())?; - if ctx.builder.get_insert_block().unwrap().get_terminator().is_none() { - generator.gen_block(ctx, orelse.iter())?; - } - let body = ctx.builder.get_insert_block().unwrap(); - // reset old_clauses and old_unwind - let (all_clauses, _, _) = ctx.outer_catch_clauses.take().unwrap(); - ctx.outer_catch_clauses = old_clauses; - ctx.unwind_target = old_unwind; - ctx.return_target = old_return; - ctx.loop_target = old_loop_target.or(ctx.loop_target).take(); - - let old_unwind = if finalbody.is_empty() { - None - } else { - let final_landingpad = ctx.ctx.append_basic_block(current_fun, "try.catch.final"); - ctx.builder.position_at_end(final_landingpad); - ctx.builder.build_landing_pad( - ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), - personality, - &[], - true, - "try.catch.final", - ); - ctx.builder.build_unconditional_branch(cleanup.unwrap()); - ctx.builder.position_at_end(body); - ctx.unwind_target.replace(final_landingpad) - }; - - // run end_catch before continue/break/return - let mut final_proxy_lambda = - |ctx: &mut CodeGenContext<'ctx, 'a>, - target: BasicBlock<'ctx>, - block: BasicBlock<'ctx>| final_proxy(ctx, target, block, final_data.as_mut().unwrap()); - let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>, - target: BasicBlock<'ctx>, - block: BasicBlock<'ctx>| { - ctx.builder.position_at_end(block); - ctx.builder.build_unconditional_branch(target); - ctx.builder.position_at_end(body); - }; - let redirect = if has_cleanup { - &mut final_proxy_lambda - as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) - } else { - &mut redirect_lambda - as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) - }; - let resume = get_builtins(generator, ctx, "__nac3_resume"); - let end_catch = get_builtins(generator, ctx, "__nac3_end_catch"); - if let Some((continue_target, break_target)) = ctx.loop_target.take() { + let mut cleanup = None; + let mut old_loop_target = None; + let mut old_return = None; + let mut final_data = None; + let has_cleanup = !finalbody.is_empty(); + if has_cleanup { + let final_state = generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?; + final_data = Some((final_state, Vec::new(), Vec::new())); + if let Some((continue_target, break_target)) = ctx.loop_target { let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); let continue_proxy = ctx.ctx.append_basic_block(current_fun, "try.continue"); - ctx.builder.position_at_end(break_proxy); - ctx.builder.build_call(end_catch, &[], "end_catch"); - ctx.builder.position_at_end(continue_proxy); - ctx.builder.build_call(end_catch, &[], "end_catch"); - ctx.builder.position_at_end(body); - redirect(ctx, break_target, break_proxy); - redirect(ctx, continue_target, continue_proxy); - ctx.loop_target = Some((continue_proxy, break_proxy)); - old_loop_target = Some((continue_target, break_target)); + final_proxy(ctx, break_target, break_proxy, final_data.as_mut().unwrap()); + final_proxy(ctx, continue_target, continue_proxy, final_data.as_mut().unwrap()); + old_loop_target = ctx.loop_target.replace((continue_proxy, break_proxy)); } let return_proxy = ctx.ctx.append_basic_block(current_fun, "try.return"); - ctx.builder.position_at_end(return_proxy); - ctx.builder.build_call(end_catch, &[], "end_catch"); - let return_target = ctx.return_target.take().unwrap_or_else(|| { - let doreturn = ctx.ctx.append_basic_block(current_fun, "try.doreturn"); - ctx.builder.position_at_end(doreturn); + if let Some(return_target) = ctx.return_target { + final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); + } else { + let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target"); + ctx.builder.position_at_end(return_target); let return_value = ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret")); ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)); - doreturn - }); - redirect(ctx, return_target, return_proxy); - ctx.return_target = Some(return_proxy); - old_return = Some(return_target); - - let mut post_handlers = Vec::new(); - - let exnid = if handlers.is_empty() { - None - } else { - ctx.builder.position_at_end(dispatcher); - unsafe { - let zero = ctx.ctx.i32_type().const_zero(); - let exnid_ptr = ctx.builder.build_gep( - exn.as_basic_value().into_pointer_value(), - &[zero, zero], - "exnidptr", - ); - Some(ctx.builder.build_load(exnid_ptr, "exnid")) - } - }; - - for (handler_node, exn_type) in handlers.iter().zip(clauses.iter()) { - let ExcepthandlerKind::ExceptHandler { type_, name, body } = &handler_node.node; - let handler_bb = ctx.ctx.append_basic_block(current_fun, "try.handler"); - ctx.builder.position_at_end(handler_bb); - if let Some(name) = name { - let exn_ty = ctx.get_llvm_type(generator, type_.as_ref().unwrap().custom.unwrap()); - let exn_store = generator.gen_var_alloc(ctx, exn_ty, Some("try.exn_store.addr"))?; - ctx.var_assignment.insert(*name, (exn_store, None, 0)); - ctx.builder.build_store(exn_store, exn.as_basic_value()); - } - generator.gen_block(ctx, body.iter())?; - let current = ctx.builder.get_insert_block().unwrap(); - // only need to call end catch if not terminated - // otherwise, we already handled in return/break/continue/raise - if current.get_terminator().is_none() { - ctx.builder.build_call(end_catch, &[], "end_catch"); - } - post_handlers.push(current); - ctx.builder.position_at_end(dispatcher_end); - if let Some(exn_type) = exn_type { - let dispatcher_cont = - ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont"); - let actual_id = exnid.unwrap().into_int_value(); - let expected_id = ctx - .builder - .build_load(exn_type.into_pointer_value(), "expected_id") - .into_int_value(); - let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck"); - ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont); - dispatcher_end = dispatcher_cont; - } else { - ctx.builder.build_unconditional_branch(handler_bb); - break; - } + ctx.builder.position_at_end(current_block); + final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); } - - ctx.unwind_target = old_unwind; - ctx.loop_target = old_loop_target.or(ctx.loop_target).take(); - ctx.return_target = old_return; - - ctx.builder.position_at_end(landingpad); - let clauses: Vec<_> = if finalbody.is_empty() { &all_clauses } else { &clauses } - .iter() - .map(|v| v.unwrap_or(ptr_type.const_zero().into())) - .collect(); - let landingpad_value = ctx - .builder - .build_landing_pad( - ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), - personality, - &clauses, - has_cleanup, - "try.landingpad", - ) - .into_struct_value(); - let exn_val = ctx.builder.build_extract_value(landingpad_value, 1, "exn").unwrap(); - ctx.builder.build_unconditional_branch(dispatcher); - exn.add_incoming(&[(&exn_val, landingpad)]); - - if dispatcher_end.get_terminator().is_none() { - ctx.builder.position_at_end(dispatcher_end); - if let Some(cleanup) = cleanup { - ctx.builder.build_unconditional_branch(cleanup); - } else if let Some((_, outer_dispatcher, phi)) = ctx.outer_catch_clauses { - phi.add_incoming(&[(&exn_val, dispatcher_end)]); - ctx.builder.build_unconditional_branch(outer_dispatcher); - } else { - ctx.build_call_or_invoke(resume, &[], "resume"); - ctx.builder.build_unreachable(); - } - } - - if finalbody.is_empty() { - let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); - if body.get_terminator().is_none() { - ctx.builder.position_at_end(body); - ctx.builder.build_unconditional_branch(tail); - } - if matches!(cleanup, Some(cleanup) if cleanup.get_terminator().is_none()) { - ctx.builder.position_at_end(cleanup.unwrap()); - ctx.builder.build_unconditional_branch(tail); - } - for post_handler in post_handlers { - if post_handler.get_terminator().is_none() { - ctx.builder.position_at_end(post_handler); - ctx.builder.build_unconditional_branch(tail); - } - } - ctx.builder.position_at_end(tail); - } else { - // exception path - let cleanup = cleanup.unwrap(); - ctx.builder.position_at_end(cleanup); - generator.gen_block(ctx, finalbody.iter())?; - if !ctx.is_terminated() { - ctx.build_call_or_invoke(resume, &[], "resume"); - ctx.builder.build_unreachable(); - } - - // normal path - let (final_state, mut final_targets, final_paths) = final_data.unwrap(); - let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); - final_targets.push(tail); - let finalizer = ctx.ctx.append_basic_block(current_fun, "try.finally"); - ctx.builder.position_at_end(finalizer); - generator.gen_block(ctx, finalbody.iter())?; - if !ctx.is_terminated() { - let dest = ctx.builder.build_load(final_state, "final_dest"); - ctx.builder.build_indirect_branch(dest, &final_targets); - } - for block in &final_paths { - if block.get_terminator().is_none() { - ctx.builder.position_at_end(*block); - ctx.builder.build_unconditional_branch(finalizer); - } - } - for block in [body].iter().chain(post_handlers.iter()) { - if block.get_terminator().is_none() { - ctx.builder.position_at_end(*block); - unsafe { - ctx.builder.build_store(final_state, tail.get_address().unwrap()); - } - ctx.builder.build_unconditional_branch(finalizer); - } - } - ctx.builder.position_at_end(tail); - } - Ok(()) - } else { - unreachable!() + old_return = ctx.return_target.replace(return_proxy); + cleanup = Some(ctx.ctx.append_basic_block(current_fun, "try.cleanup")); } + + let mut clauses = Vec::new(); + let mut found_catch_all = false; + for handler_node in handlers { + let ExcepthandlerKind::ExceptHandler { type_, .. } = &handler_node.node; + // none or Exception + if type_.is_none() + || ctx + .unifier + .unioned(type_.as_ref().unwrap().custom.unwrap(), ctx.primitives.exception) + { + clauses.push(None); + found_catch_all = true; + break; + } + + let type_ = type_.as_ref().unwrap(); + let exn_name = ctx.resolver.get_type_name( + &ctx.top_level.definitions.read(), + &mut ctx.unifier, + type_.custom.unwrap(), + ); + let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { + *obj_id + } else { + unreachable!() + }; + let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); + let exn_id = ctx.resolver.get_string_id(&exception_name); + let exn_id_global = + ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{exn_id}")); + exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false)); + clauses.push(Some(exn_id_global.as_pointer_value().as_basic_value_enum())); + } + let mut all_clauses = clauses.clone(); + if let Some(old_clauses) = &ctx.outer_catch_clauses { + if !found_catch_all { + all_clauses.extend_from_slice(&old_clauses.0); + } + } + let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn)); + let old_unwind = ctx.unwind_target.replace(landingpad); + generator.gen_block(ctx, body.iter())?; + if ctx.builder.get_insert_block().unwrap().get_terminator().is_none() { + generator.gen_block(ctx, orelse.iter())?; + } + let body = ctx.builder.get_insert_block().unwrap(); + // reset old_clauses and old_unwind + let (all_clauses, _, _) = ctx.outer_catch_clauses.take().unwrap(); + ctx.outer_catch_clauses = old_clauses; + ctx.unwind_target = old_unwind; + ctx.return_target = old_return; + ctx.loop_target = old_loop_target.or(ctx.loop_target).take(); + + let old_unwind = if finalbody.is_empty() { + None + } else { + let final_landingpad = ctx.ctx.append_basic_block(current_fun, "try.catch.final"); + ctx.builder.position_at_end(final_landingpad); + ctx.builder.build_landing_pad( + ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), + personality, + &[], + true, + "try.catch.final", + ); + ctx.builder.build_unconditional_branch(cleanup.unwrap()); + ctx.builder.position_at_end(body); + ctx.unwind_target.replace(final_landingpad) + }; + + // run end_catch before continue/break/return + let mut final_proxy_lambda = + |ctx: &mut CodeGenContext<'ctx, 'a>, + target: BasicBlock<'ctx>, + block: BasicBlock<'ctx>| final_proxy(ctx, target, block, final_data.as_mut().unwrap()); + let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>, + target: BasicBlock<'ctx>, + block: BasicBlock<'ctx>| { + ctx.builder.position_at_end(block); + ctx.builder.build_unconditional_branch(target); + ctx.builder.position_at_end(body); + }; + let redirect = if has_cleanup { + &mut final_proxy_lambda + as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) + } else { + &mut redirect_lambda + as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) + }; + let resume = get_builtins(generator, ctx, "__nac3_resume"); + let end_catch = get_builtins(generator, ctx, "__nac3_end_catch"); + if let Some((continue_target, break_target)) = ctx.loop_target.take() { + let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); + let continue_proxy = ctx.ctx.append_basic_block(current_fun, "try.continue"); + ctx.builder.position_at_end(break_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + ctx.builder.position_at_end(continue_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + ctx.builder.position_at_end(body); + redirect(ctx, break_target, break_proxy); + redirect(ctx, continue_target, continue_proxy); + ctx.loop_target = Some((continue_proxy, break_proxy)); + old_loop_target = Some((continue_target, break_target)); + } + let return_proxy = ctx.ctx.append_basic_block(current_fun, "try.return"); + ctx.builder.position_at_end(return_proxy); + ctx.builder.build_call(end_catch, &[], "end_catch"); + let return_target = ctx.return_target.take().unwrap_or_else(|| { + let doreturn = ctx.ctx.append_basic_block(current_fun, "try.doreturn"); + ctx.builder.position_at_end(doreturn); + let return_value = ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret")); + ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)); + doreturn + }); + redirect(ctx, return_target, return_proxy); + ctx.return_target = Some(return_proxy); + old_return = Some(return_target); + + let mut post_handlers = Vec::new(); + + let exnid = if handlers.is_empty() { + None + } else { + ctx.builder.position_at_end(dispatcher); + unsafe { + let zero = ctx.ctx.i32_type().const_zero(); + let exnid_ptr = ctx.builder.build_gep( + exn.as_basic_value().into_pointer_value(), + &[zero, zero], + "exnidptr", + ); + Some(ctx.builder.build_load(exnid_ptr, "exnid")) + } + }; + + for (handler_node, exn_type) in handlers.iter().zip(clauses.iter()) { + let ExcepthandlerKind::ExceptHandler { type_, name, body } = &handler_node.node; + let handler_bb = ctx.ctx.append_basic_block(current_fun, "try.handler"); + ctx.builder.position_at_end(handler_bb); + if let Some(name) = name { + let exn_ty = ctx.get_llvm_type(generator, type_.as_ref().unwrap().custom.unwrap()); + let exn_store = generator.gen_var_alloc(ctx, exn_ty, Some("try.exn_store.addr"))?; + ctx.var_assignment.insert(*name, (exn_store, None, 0)); + ctx.builder.build_store(exn_store, exn.as_basic_value()); + } + generator.gen_block(ctx, body.iter())?; + let current = ctx.builder.get_insert_block().unwrap(); + // only need to call end catch if not terminated + // otherwise, we already handled in return/break/continue/raise + if current.get_terminator().is_none() { + ctx.builder.build_call(end_catch, &[], "end_catch"); + } + post_handlers.push(current); + ctx.builder.position_at_end(dispatcher_end); + if let Some(exn_type) = exn_type { + let dispatcher_cont = + ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont"); + let actual_id = exnid.unwrap().into_int_value(); + let expected_id = ctx + .builder + .build_load(exn_type.into_pointer_value(), "expected_id") + .into_int_value(); + let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck"); + ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont); + dispatcher_end = dispatcher_cont; + } else { + ctx.builder.build_unconditional_branch(handler_bb); + break; + } + } + + ctx.unwind_target = old_unwind; + ctx.loop_target = old_loop_target.or(ctx.loop_target).take(); + ctx.return_target = old_return; + + ctx.builder.position_at_end(landingpad); + let clauses: Vec<_> = if finalbody.is_empty() { &all_clauses } else { &clauses } + .iter() + .map(|v| v.unwrap_or(ptr_type.const_zero().into())) + .collect(); + let landingpad_value = ctx + .builder + .build_landing_pad( + ctx.ctx.struct_type(&[ptr_type.into(), exception_type], false), + personality, + &clauses, + has_cleanup, + "try.landingpad", + ) + .into_struct_value(); + let exn_val = ctx.builder.build_extract_value(landingpad_value, 1, "exn").unwrap(); + ctx.builder.build_unconditional_branch(dispatcher); + exn.add_incoming(&[(&exn_val, landingpad)]); + + if dispatcher_end.get_terminator().is_none() { + ctx.builder.position_at_end(dispatcher_end); + if let Some(cleanup) = cleanup { + ctx.builder.build_unconditional_branch(cleanup); + } else if let Some((_, outer_dispatcher, phi)) = ctx.outer_catch_clauses { + phi.add_incoming(&[(&exn_val, dispatcher_end)]); + ctx.builder.build_unconditional_branch(outer_dispatcher); + } else { + ctx.build_call_or_invoke(resume, &[], "resume"); + ctx.builder.build_unreachable(); + } + } + + if finalbody.is_empty() { + let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); + if body.get_terminator().is_none() { + ctx.builder.position_at_end(body); + ctx.builder.build_unconditional_branch(tail); + } + if matches!(cleanup, Some(cleanup) if cleanup.get_terminator().is_none()) { + ctx.builder.position_at_end(cleanup.unwrap()); + ctx.builder.build_unconditional_branch(tail); + } + for post_handler in post_handlers { + if post_handler.get_terminator().is_none() { + ctx.builder.position_at_end(post_handler); + ctx.builder.build_unconditional_branch(tail); + } + } + ctx.builder.position_at_end(tail); + } else { + // exception path + let cleanup = cleanup.unwrap(); + ctx.builder.position_at_end(cleanup); + generator.gen_block(ctx, finalbody.iter())?; + if !ctx.is_terminated() { + ctx.build_call_or_invoke(resume, &[], "resume"); + ctx.builder.build_unreachable(); + } + + // normal path + let (final_state, mut final_targets, final_paths) = final_data.unwrap(); + let tail = ctx.ctx.append_basic_block(current_fun, "try.tail"); + final_targets.push(tail); + let finalizer = ctx.ctx.append_basic_block(current_fun, "try.finally"); + ctx.builder.position_at_end(finalizer); + generator.gen_block(ctx, finalbody.iter())?; + if !ctx.is_terminated() { + let dest = ctx.builder.build_load(final_state, "final_dest"); + ctx.builder.build_indirect_branch(dest, &final_targets); + } + for block in &final_paths { + if block.get_terminator().is_none() { + ctx.builder.position_at_end(*block); + ctx.builder.build_unconditional_branch(finalizer); + } + } + for block in [body].iter().chain(post_handlers.iter()) { + if block.get_terminator().is_none() { + ctx.builder.position_at_end(*block); + unsafe { + ctx.builder.build_store(final_state, tail.get_address().unwrap()); + } + ctx.builder.build_unconditional_branch(finalizer); + } + } + ctx.builder.position_at_end(tail); + } + + Ok(()) } /// See [`CodeGenerator::gen_with`]. diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 54a243a..2f08ecc 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -528,11 +528,11 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() { - name.to_string() - } else { + let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { unreachable!("expected class definition") - } + }; + + name.to_string() }, &mut |id| format!("typevar{id}"), &mut None, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 298f189..0b8ccde 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -421,11 +421,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { generator, expect_ty, )?; - if let BasicValueEnum::PointerValue(ptr) = obj_val { - Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").into())) - } else { + let BasicValueEnum::PointerValue(ptr) = obj_val else { unreachable!("option must be ptr") - } + }; + + Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").into())) }, )))), loc: None, @@ -446,11 +446,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { generator, expect_ty, )?; - if let BasicValueEnum::PointerValue(ptr) = obj_val { - Ok(Some(ctx.builder.build_is_null(ptr, "is_none").into())) - } else { + let BasicValueEnum::PointerValue(ptr) = obj_val else { unreachable!("option must be ptr") - } + }; + + Ok(Some(ctx.builder.build_is_null(ptr, "is_none").into())) }, )))), loc: None, @@ -686,7 +686,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { val } else { - unreachable!(); + unreachable!() }; Ok(Some(res)) }, @@ -762,7 +762,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { val } else { - unreachable!(); + unreachable!() }; Ok(Some(res)) }, @@ -1361,7 +1361,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { } else if is_type(m_ty, n_ty) && is_type(n_ty, float) { ("llvm.minnum.f64", llvm_f64) } else { - unreachable!(); + unreachable!() }; let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false); @@ -1423,7 +1423,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { } else if is_type(m_ty, n_ty) && is_type(n_ty, float) { ("llvm.maxnum.f64", llvm_f64) } else { - unreachable!(); + unreachable!() }; let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false); @@ -1480,7 +1480,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { is_float = true; ("llvm.fabs.f64", llvm_f64) } else { - unreachable!(); + unreachable!() }; let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { let fn_type = if is_float { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index e3fb80b..68cfb72 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -300,12 +300,12 @@ impl TopLevelComposer { // get the methods into the top level class_def for (name, _, id, ty, ..) in &class_method_name_def_ids { let mut class_def = class_def_ast.0.write(); - if let TopLevelDef::Class { methods, .. } = &mut *class_def { - methods.push((*name, *ty, *id)); - self.method_class.insert(*id, DefinitionId(class_def_id)); - } else { + let TopLevelDef::Class { methods, .. } = &mut *class_def else { unreachable!() - } + }; + + methods.push((*name, *ty, *id)); + self.method_class.insert(*id, DefinitionId(class_def_id)); } // now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order self.definition_ast_list.push(class_def_ast); @@ -385,14 +385,13 @@ impl TopLevelComposer { let mut class_def = class_def.write(); let (class_bases_ast, class_def_type_vars, class_resolver) = { if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { - if let Some(ast::Located { + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast - { - (bases, type_vars, resolver) - } else { - unreachable!("must be both class") - } + }) = class_ast else { + unreachable!() + }; + + (bases, type_vars, resolver) } else { return Ok(()); } @@ -515,15 +514,14 @@ impl TopLevelComposer { ancestors, resolver, object_id, type_vars, .. } = &mut *class_def { - if let Some(ast::Located { + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast - { - (object_id, bases, ancestors, resolver, type_vars) - } else { - unreachable!("must be both class") - } + }) = class_ast else { + unreachable!() + }; + + (object_id, bases, ancestors, resolver, type_vars) } else { return Ok(()); } @@ -659,32 +657,31 @@ impl TopLevelComposer { .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { // if inherited from Exception, the body should be a pass - if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node { - for stmt in body { - if matches!( - stmt.node, - ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } - ) { - return Err(HashSet::from([ - "Classes inherited from exception should have no custom fields/methods".into() - ])) - } - } - } else { + let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { unreachable!() + }; + + for stmt in body { + if matches!( + stmt.node, + ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } + ) { + return Err(HashSet::from([ + "Classes inherited from exception should have no custom fields/methods".into() + ])) + } } } } // deal with ancestor of Exception object - if let TopLevelDef::Class { name, ancestors, object_id, .. } = - &mut *self.definition_ast_list[7].0.write() - { - assert_eq!(*name, "Exception".into()); - ancestors.push(make_self_type_annotation(&[], *object_id)); - } else { - unreachable!(); - } + let TopLevelDef::Class { name, ancestors, object_id, .. } = + &mut *self.definition_ast_list[7].0.write() else { + unreachable!() + }; + + assert_eq!(*name, "Exception".into()); + ancestors.push(make_self_type_annotation(&[], *object_id)); Ok(()) } @@ -775,26 +772,26 @@ impl TopLevelComposer { } } for ty in subst_list.unwrap() { - if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) { - let mut new_fields = HashMap::new(); - let mut need_subst = false; - for (name, (ty, mutable)) in fields { - let substituted = unifier.subst(*ty, params); - need_subst |= substituted.is_some(); - new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); - } - if need_subst { - let new_ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - params: params.clone(), - fields: new_fields, - }); - if let Err(e) = unifier.unify(ty, new_ty) { - errors.insert(e.to_display(unifier).to_string()); - } - } - } else { + let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else { unreachable!() + }; + + let mut new_fields = HashMap::new(); + let mut need_subst = false; + for (name, (ty, mutable)) in fields { + let substituted = unifier.subst(*ty, params); + need_subst |= substituted.is_some(); + new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); + } + if need_subst { + let new_ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + params: params.clone(), + fields: new_fields, + }); + if let Err(e) = unifier.unify(ty, new_ty) { + errors.insert(e.to_display(unifier).to_string()); + } } } if !errors.is_empty() { @@ -833,203 +830,199 @@ impl TopLevelComposer { return Ok(()); }; - if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = - function_def - { - if matches!(unifier.get_ty(*dummy_ty).as_ref(), TypeEnum::TFunc(_)) { - // already have a function type, is class method, skip - return Ok(()); - } - if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { - let resolver = resolver.as_ref(); - let resolver = resolver.unwrap(); - let resolver = &**resolver; - - let mut function_var_map: HashMap = HashMap::new(); - let arg_types = { - // make sure no duplicate parameter - let mut defined_parameter_name: HashSet<_> = HashSet::new(); - for x in &args.args { - if !defined_parameter_name.insert(x.node.arg) - || keyword_list.contains(&x.node.arg) - { - return Err(HashSet::from([ - format!( - "top level function must have unique parameter names \ - and names should not be the same as the keywords (at {})", - x.location - ), - ])) - } - } - - let arg_with_default: Vec<( - &ast::Located>, - Option<&ast::Expr>, - )> = args - .args - .iter() - .rev() - .zip( - args.defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)), - ) - .collect_vec(); - - arg_with_default - .iter() - .rev() - .map(|(x, default)| -> Result> { - let annotation = x - .node - .annotation - .as_ref() - .ok_or_else(|| HashSet::from([ - format!( - "function parameter `{}` needs type annotation at {}", - x.node.arg, x.location - ), - ]))? - .as_ref(); - - let type_annotation = parse_ast_to_type_annotation_kinds( - resolver, - temp_def_list.as_slice(), - unifier, - primitives_store, - annotation, - // NOTE: since only class need this, for function - // it should be fine to be empty map - HashMap::new(), - None, - )?; - - let type_vars_within = - get_type_var_contained_in_type_annotation(&type_annotation) - .into_iter() - .map(|x| -> Result<(u32, Type), HashSet> { - if let TypeAnnotation::TypeVar(ty) = x { - Ok((Self::get_var_id(ty, unifier)?, ty)) - } else { - unreachable!("must be type var annotation kind") - } - }) - .collect::, _>>()?; - for (id, ty) in type_vars_within { - if let Some(prev_ty) = function_var_map.insert(id, ty) { - // if already have the type inserted, make sure they are the same thing - assert_eq!(prev_ty, ty); - } - } - - let ty = get_type_from_type_annotation_kinds( - temp_def_list.as_ref(), - unifier, - &type_annotation, - &mut None - )?; - - Ok(FuncArg { - name: x.node.arg, - ty, - default_value: match default { - None => None, - Some(default) => Some({ - let v = Self::parse_parameter_default_value( - default, resolver, - )?; - Self::check_default_param_type( - &v, - &type_annotation, - primitives_store, - unifier, - ) - .map_err(|err| HashSet::from([ - format!("{} (at {})", err, x.location), - ]))?; - v - }), - }, - }) - }) - .collect::, _>>()? - }; - - let return_ty = { - if let Some(returns) = returns { - let return_ty_annotation = { - let return_annotation = returns.as_ref(); - parse_ast_to_type_annotation_kinds( - resolver, - &temp_def_list, - unifier, - primitives_store, - return_annotation, - // NOTE: since only class need this, for function - // it should be fine to be empty map - HashMap::new(), - None, - )? - }; - - let type_vars_within = - get_type_var_contained_in_type_annotation(&return_ty_annotation) - .into_iter() - .map(|x| -> Result<(u32, Type), HashSet> { - if let TypeAnnotation::TypeVar(ty) = x { - Ok((Self::get_var_id(ty, unifier)?, ty)) - } else { - unreachable!("must be type var here") - } - }) - .collect::, _>>()?; - for (id, ty) in type_vars_within { - if let Some(prev_ty) = function_var_map.insert(id, ty) { - // if already have the type inserted, make sure they are the same thing - assert_eq!(prev_ty, ty); - } - } - - get_type_from_type_annotation_kinds( - &temp_def_list, - unifier, - &return_ty_annotation, - &mut None - )? - } else { - primitives_store.none - } - }; - var_id.extend_from_slice(function_var_map - .iter() - .filter_map(|(id, ty)| { - if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { - None - } else { - Some(*id) - } - }) - .collect_vec() - .as_slice() - ); - let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: arg_types, - ret: return_ty, - vars: function_var_map, - })); - unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([ - e.at(Some(function_ast.location)).to_display(unifier).to_string(), - ]))?; - } else { - unreachable!("must be both function"); - } - } else { + let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def else { // not top level function def, skip return Ok(()); + }; + + if matches!(unifier.get_ty(*dummy_ty).as_ref(), TypeEnum::TFunc(_)) { + // already have a function type, is class method, skip + return Ok(()); } + let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node else { + unreachable!("must be both function"); + }; + + let resolver = resolver.as_ref(); + let resolver = resolver.unwrap(); + let resolver = &**resolver; + + let mut function_var_map: HashMap = HashMap::new(); + let arg_types = { + // make sure no duplicate parameter + let mut defined_parameter_name: HashSet<_> = HashSet::new(); + for x in &args.args { + if !defined_parameter_name.insert(x.node.arg) + || keyword_list.contains(&x.node.arg) + { + return Err(HashSet::from([format!( + "top level function must have unique parameter names \ + and names should not be the same as the keywords (at {})", + x.location + ), + ])) + }} + + let arg_with_default: Vec<( + &ast::Located>, + Option<&ast::Expr>, + )> = args + .args + .iter() + .rev() + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); + + arg_with_default + .iter() + .rev() + .map(|(x, default)| -> Result> { + let annotation = x + .node + .annotation + .as_ref() + .ok_or_else(|| HashSet::from([ + format!( + "function parameter `{}` needs type annotation at {}", + x.node.arg, x.location + ), + ]))? + .as_ref(); + + let type_annotation = parse_ast_to_type_annotation_kinds( + resolver, + temp_def_list.as_slice(), + unifier, + primitives_store, + annotation, + // NOTE: since only class need this, for function + // it should be fine to be empty map + HashMap::new(), + None, + )?; + + let type_vars_within = + get_type_var_contained_in_type_annotation(&type_annotation) + .into_iter() + .map(|x| -> Result<(u32, Type), HashSet> { + let TypeAnnotation::TypeVar(ty) = x else { + unreachable!("must be type var annotation kind") + }; + + Ok((Self::get_var_id(ty, unifier)?, ty)) + }) + .collect::, _>>()?; + for (id, ty) in type_vars_within { + if let Some(prev_ty) = function_var_map.insert(id, ty) { + // if already have the type inserted, make sure they are the same thing + assert_eq!(prev_ty, ty); + } + } + + let ty = get_type_from_type_annotation_kinds( + temp_def_list.as_ref(), + unifier, + &type_annotation, + &mut None + )?; + + Ok(FuncArg { + name: x.node.arg, + ty, + default_value: match default { + None => None, + Some(default) => Some({ + let v = Self::parse_parameter_default_value( + default, resolver, + )?; + Self::check_default_param_type( + &v, + &type_annotation, + primitives_store, + unifier, + ) + .map_err( + |err| HashSet::from([format!("{} (at {})", err, x.location), + ]))?; + v + }), + }, + }) + }) + .collect::, _>>()? + }; + + let return_ty = { + if let Some(returns) = returns { + let return_ty_annotation = { + let return_annotation = returns.as_ref(); + parse_ast_to_type_annotation_kinds( + resolver, + &temp_def_list, + unifier, + primitives_store, + return_annotation, + // NOTE: since only class need this, for function + // it should be fine to be empty map + HashMap::new(), + None, + )? + }; + + let type_vars_within = + get_type_var_contained_in_type_annotation(&return_ty_annotation) + .into_iter() + .map(|x| -> Result<(u32, Type), HashSet> { + let TypeAnnotation::TypeVar(ty) = x else { + unreachable!("must be type var here") + }; + + Ok((Self::get_var_id(ty, unifier)?, ty)) + }) + .collect::, _>>()?; + for (id, ty) in type_vars_within { + if let Some(prev_ty) = function_var_map.insert(id, ty) { + // if already have the type inserted, make sure they are the same thing + assert_eq!(prev_ty, ty); + } + } + + get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + &return_ty_annotation, + &mut None + )? + } else { + primitives_store.none + } + }; + var_id.extend_from_slice(function_var_map + .iter() + .filter_map(|(id, ty)| { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { + None + } else { + Some(*id) + } + }) + .collect_vec() + .as_slice() + ); + let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_types, + ret: return_ty, + vars: function_var_map, + })); + unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([ + e.at(Some(function_ast.location)).to_display(unifier).to_string(), + ]))?; Ok(()) }; for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { @@ -1057,6 +1050,21 @@ impl TopLevelComposer { ) -> Result<(), HashSet> { let (keyword_list, core_config) = core_info; let mut class_def = class_def.write(); + let TopLevelDef::Class { + object_id, + ancestors, + fields, + methods, + resolver, + type_vars, + .. + } = &mut *class_def else { + unreachable!("here must be toplevel class def"); + }; + let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else { + unreachable!("here must be class def ast") + }; + let ( class_id, _class_name, @@ -1067,24 +1075,8 @@ impl TopLevelComposer { class_methods_def, class_type_vars_def, class_resolver, - ) = if let TopLevelDef::Class { - object_id, - ancestors, - fields, - methods, - resolver, - type_vars, - .. - } = &mut *class_def - { - if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { - (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver) - } else { - unreachable!("here must be class def ast"); - } - } else { - unreachable!("here must be toplevel class def"); - }; + ) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver); + let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); @@ -1174,14 +1166,14 @@ impl TopLevelComposer { get_type_var_contained_in_type_annotation(&type_ann); // handle the class type var and the method type var for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVar(ty) = type_var_within { - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); - } - } else { - unreachable!("must be type var annotation"); + let TypeAnnotation::TypeVar(ty) = type_var_within else { + unreachable!("must be type var annotation") + }; + + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); } } // finish handling type vars @@ -1239,14 +1231,14 @@ impl TopLevelComposer { get_type_var_contained_in_type_annotation(&annotation); // handle the class type var and the method type var for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVar(ty) = type_var_within { - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); - } - } else { + let TypeAnnotation::TypeVar(ty) = type_var_within else { unreachable!("must be type var annotation"); + }; + + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); } } let dummy_return_type = unifier.get_dummy_var().0; @@ -1264,24 +1256,22 @@ impl TopLevelComposer { } }; - if let TopLevelDef::Function { var_id, .. } = - &mut *temp_def_list.get(method_id.0).unwrap().write() - { - var_id.extend_from_slice(method_var_map - .iter() - .filter_map(|(id, ty)| { - if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { - None - } else { - Some(*id) - } - }) - .collect_vec() - .as_slice() - ); - } else { + let TopLevelDef::Function { var_id, .. } = + &mut *temp_def_list.get(method_id.0).unwrap().write() else { unreachable!() - } + }; + var_id.extend_from_slice(method_var_map + .iter() + .filter_map(|(id, ty)| { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { + None + } else { + Some(*id) + } + }) + .collect_vec() + .as_slice() + ); let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: arg_types, ret: ret_type, @@ -1336,18 +1326,18 @@ impl TopLevelComposer { get_type_var_contained_in_type_annotation(&parsed_annotation); // handle the class type var and the method type var for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVar(t) = type_var_within { - if !class_type_vars_def.contains(&t) { - return Err(HashSet::from([ - format!( - "class fields can only use type \ + let TypeAnnotation::TypeVar(t) = type_var_within else { + unreachable!("must be type var annotation") + }; + + if !class_type_vars_def.contains(&t) { + return Err(HashSet::from([ + format!( + "class fields can only use type \ vars over which the class is generic (at {})", - annotation.location - ), - ])) - } - } else { - unreachable!("must be type var annotation"); + annotation.location + ), + ])) } } type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation); @@ -1391,14 +1381,7 @@ impl TopLevelComposer { _primitives: &PrimitiveStore, type_var_to_concrete_def: &mut HashMap, ) -> Result<(), HashSet> { - let ( - _class_id, - class_ancestor_def, - class_fields_def, - class_methods_def, - _class_type_vars_def, - _class_resolver, - ) = if let TopLevelDef::Class { + let TopLevelDef::Class { object_id, ancestors, fields, @@ -1406,102 +1389,103 @@ impl TopLevelComposer { resolver, type_vars, .. - } = class_def - { - (*object_id, ancestors, fields, methods, type_vars, resolver) - } else { - unreachable!("here must be class def ast"); + } = class_def else { + unreachable!("here must be class def ast") }; + let ( + _class_id, + class_ancestor_def, + class_fields_def, + class_methods_def, + _class_type_vars_def, + _class_resolver, + ) = (*object_id, ancestors, fields, methods, type_vars, resolver); // since when this function is called, the ancestors of the direct parent // are supposed to be already handled, so we only need to deal with the direct parent let base = class_ancestor_def.get(1).unwrap(); - if let TypeAnnotation::CustomClass { id, params: _ } = base { - let base = temp_def_list.get(id.0).unwrap(); - let base = base.read(); - if let TopLevelDef::Class { methods, fields, .. } = &*base { - // handle methods override - // since we need to maintain the order, create a new list - let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new(); - let mut is_override: HashSet = HashSet::new(); - for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { - // find if there is a method with same name in the child class - let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); - for (class_method_name, class_method_ty, class_method_defid) in - &*class_methods_def - { - if class_method_name == anc_method_name { - // ignore and handle self - // if is __init__ method, no need to check return type - let ok = class_method_name == &"__init__".into() - || Self::check_overload_function_type( - *class_method_ty, - *anc_method_ty, - unifier, - type_var_to_concrete_def, - ); - if !ok { - return Err(HashSet::from([ - format!( - "method {class_method_name} has same name as ancestors' method, but incompatible type" - ), - ])) - } - // mark it as added - is_override.insert(*class_method_name); - to_be_added = - (*class_method_name, *class_method_ty, *class_method_defid); - break; - } - } - new_child_methods.push(to_be_added); - } - // add those that are not overriding method to the new_child_methods - for (class_method_name, class_method_ty, class_method_defid) in - &*class_methods_def - { - if !is_override.contains(class_method_name) { - new_child_methods.push(( - *class_method_name, - *class_method_ty, - *class_method_defid, - )); - } - } - // use the new_child_methods to replace all the elements in `class_methods_def` - class_methods_def.drain(..); - class_methods_def.extend(new_child_methods); - - // handle class fields - let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new(); - // let mut is_override: HashSet<_> = HashSet::new(); - for (anc_field_name, anc_field_ty, mutable) in fields { - let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); - // find if there is a fields with the same name in the child class - for (class_field_name, ..) in &*class_fields_def { - if class_field_name == anc_field_name { - return Err(HashSet::from([ - format!( - "field `{class_field_name}` has already declared in the ancestor classes" - ), - ])) - } - } - new_child_fields.push(to_be_added); - } - for (class_field_name, class_field_ty, mutable) in &*class_fields_def { - if !is_override.contains(class_field_name) { - new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); - } - } - class_fields_def.drain(..); - class_fields_def.extend(new_child_fields); - } else { - unreachable!("must be top level class def") - } - } else { + let TypeAnnotation::CustomClass { id, params: _ } = base else { unreachable!("must be class type annotation") + }; + + let base = temp_def_list.get(id.0).unwrap(); + let base = base.read(); + let TopLevelDef::Class { methods, fields, .. } = &*base else { + unreachable!("must be top level class def") + }; + + // handle methods override + // since we need to maintain the order, create a new list + let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new(); + let mut is_override: HashSet = HashSet::new(); + for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { + // find if there is a method with same name in the child class + let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); + for (class_method_name, class_method_ty, class_method_defid) in + &*class_methods_def + { + if class_method_name == anc_method_name { + // ignore and handle self + // if is __init__ method, no need to check return type + let ok = class_method_name == &"__init__".into() + || Self::check_overload_function_type( + *class_method_ty, + *anc_method_ty, + unifier, + type_var_to_concrete_def, + ); + if !ok { + return Err(HashSet::from([format!( + "method {class_method_name} has same name as ancestors' method, but incompatible type"), + ])) + } + // mark it as added + is_override.insert(*class_method_name); + to_be_added = + (*class_method_name, *class_method_ty, *class_method_defid); + break; + } + } + new_child_methods.push(to_be_added); } + // add those that are not overriding method to the new_child_methods + for (class_method_name, class_method_ty, class_method_defid) in + &*class_methods_def + { + if !is_override.contains(class_method_name) { + new_child_methods.push(( + *class_method_name, + *class_method_ty, + *class_method_defid, + )); + } + } + // use the new_child_methods to replace all the elements in `class_methods_def` + class_methods_def.drain(..); + class_methods_def.extend(new_child_methods); + + // handle class fields + let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new(); + // let mut is_override: HashSet<_> = HashSet::new(); + for (anc_field_name, anc_field_ty, mutable) in fields { + let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); + // find if there is a fields with the same name in the child class + for (class_field_name, ..) in &*class_fields_def { + if class_field_name == anc_field_name { + return Err(HashSet::from([format!( + "field `{class_field_name}` has already declared in the ancestor classes"), + ])) + } + } + new_child_fields.push(to_be_added); + } + for (class_field_name, class_field_ty, mutable) in &*class_fields_def { + if !is_override.contains(class_field_name) { + new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); + } + } + class_fields_def.drain(..); + class_fields_def.extend(new_child_fields); Ok(()) } @@ -1626,14 +1610,14 @@ impl TopLevelComposer { for (name, func_sig, id) in methods { if *name == init_str_id { init_id = Some(*id); - if let TypeEnum::TFunc(FunSignature { args, vars, .. }) = - unifier.get_ty(*func_sig).as_ref() - { - constructor_args.extend_from_slice(args); - type_vars.extend(vars); - } else { + let func_ty_enum = unifier.get_ty(*func_sig); + let TypeEnum::TFunc(FunSignature { args, vars, .. }) = + func_ty_enum.as_ref() else { unreachable!("must be typeenum::tfunc") - } + }; + + constructor_args.extend_from_slice(args); + type_vars.extend(vars); } } (constructor_args, type_vars) @@ -1685,16 +1669,15 @@ impl TopLevelComposer { } for (i, signature, id) in constructors { - if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() - { - methods.push(( - init_str_id, - signature, - DefinitionId(self.definition_ast_list.len() + id), - )); - } else { + let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() else { unreachable!() - } + }; + + methods.push(( + init_str_id, + signature, + DefinitionId(self.definition_ast_list.len() + id), + )); } self.definition_ast_list.extend_from_slice(&definition_extension); @@ -1720,259 +1703,250 @@ impl TopLevelComposer { .. } = &mut *function_def { - if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = - unifier.get_ty(*signature).as_ref() - { - let mut vars = vars.clone(); - // None if is not class method - let uninst_self_type = { - if let Some(class_id) = method_class.get(&DefinitionId(id)) { - let class_def = definition_ast_list.get(class_id.0).unwrap(); - let class_def = class_def.0.read(); - if let TopLevelDef::Class { type_vars, .. } = &*class_def { - let ty_ann = make_self_type_annotation(type_vars, *class_id); - let self_ty = get_type_from_type_annotation_kinds( - &def_list, - unifier, - &ty_ann, - &mut None - )?; - vars.extend(type_vars.iter().map(|ty| - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) { - (*id, *ty) - } else { - unreachable!() - })); - Some((self_ty, type_vars.clone())) - } else { - unreachable!("must be class def") - } - } else { - None - } - }; - // carefully handle those with bounds, without bounds and no typevars - // if class methods, `vars` also contains all class typevars here - let (type_var_subst_comb, no_range_vars) = { - let mut no_ranges: Vec = Vec::new(); - let var_combs = vars - .values() - .map(|ty| { - unifier.get_instantiations(*ty).unwrap_or_else(|| { - if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) - { - let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; - no_ranges.push(rigid); - vec![rigid] - } else { - unreachable!() - } - }) - }) - .multi_cartesian_product() - .collect_vec(); - let mut result: Vec> = Vec::default(); - for comb in var_combs { - result.push(vars.keys().copied().zip(comb).collect()); - } - // NOTE: if is empty, means no type var, append a empty subst, ok to do this? - if result.is_empty() { - result.push(HashMap::new()); - } - (result, no_ranges) - }; - - for subst in type_var_subst_comb { - // for each instance - let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); - let inst_args = { - args.iter() - .map(|a| FuncArg { - name: a.name, - ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), - default_value: a.default_value.clone(), - }) - .collect_vec() - }; - let self_type = { - uninst_self_type.clone().map(|(self_type, type_vars)| { - let subst_for_self = { - let class_ty_var_ids = type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) - { - *id - } else { - unreachable!("must be type var here"); - } - }) - .collect::>(); - subst - .iter() - .filter_map(|(ty_var_id, ty_var_target)| { - if class_ty_var_ids.contains(ty_var_id) { - Some((*ty_var_id, *ty_var_target)) - } else { - None - } - }) - .collect::>() - }; - unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) - }) - }; - let mut identifiers = { - let mut result: HashSet<_> = HashSet::new(); - if self_type.is_some() { - result.insert("self".into()); - } - result.extend(inst_args.iter().map(|x| x.name)); - result - }; - let mut calls: HashMap = HashMap::new(); - let mut inferencer = Inferencer { - top_level: ctx.as_ref(), - defined_identifiers: identifiers.clone(), - function_data: &mut FunctionData { - resolver: resolver.as_ref().unwrap().clone(), - return_type: if unifier.unioned(inst_ret, primitives_ty.none) { - None - } else { - Some(inst_ret) - }, - // NOTE: allowed type vars - bound_variables: no_range_vars.clone(), - }, - unifier, - variable_mapping: { - let mut result: HashMap = HashMap::new(); - if let Some(self_ty) = self_type { - result.insert("self".into(), self_ty); - } - result.extend(inst_args.iter().map(|x| (x.name, x.ty))); - result - }, - primitives: primitives_ty, - virtual_checks: &mut Vec::new(), - calls: &mut calls, - in_handler: false, - }; - - let fun_body = - if let ast::StmtKind::FunctionDef { body, decorator_list, .. } = - ast.clone().unwrap().node - { - if !decorator_list.is_empty() - && matches!(&decorator_list[0].node, - ast::ExprKind::Name{ id, .. } if id == &"extern".into()) - { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } - if !decorator_list.is_empty() - && matches!(&decorator_list[0].node, - ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) - { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } - body - } else { - unreachable!("must be function def ast") - } - .into_iter() - .map(|b| inferencer.fold_stmt(b)) - .collect::, _>>()?; - - let returned = - inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; - { - // check virtuals - let defs = ctx.definitions.read(); - for (subtype, base, loc) in &*inferencer.virtual_checks { - let base_id = { - let base = inferencer.unifier.get_ty(*base); - if let TypeEnum::TObj { obj_id, .. } = &*base { - *obj_id - } else { - return Err(HashSet::from([ - format!("Base type should be a class (at {loc})"), - ])) - } - }; - let subtype_id = { - let ty = inferencer.unifier.get_ty(*subtype); - if let TypeEnum::TObj { obj_id, .. } = &*ty { - *obj_id - } else { - let base_repr = inferencer.unifier.stringify(*base); - let subtype_repr = inferencer.unifier.stringify(*subtype); - return Err(HashSet::from([ - format!( - "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})" - ), - ])) - } - }; - let subtype_entry = defs[subtype_id.0].read(); - if let TopLevelDef::Class { ancestors, .. } = &*subtype_entry { - let m = ancestors.iter() - .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); - if m.is_none() { - let base_repr = inferencer.unifier.stringify(*base); - let subtype_repr = inferencer.unifier.stringify(*subtype); - return Err(HashSet::from([ - format!( - "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})" - ), - ])) - } - } else { - unreachable!(); - } - } - } - if !unifier.unioned(inst_ret, primitives_ty.none) && !returned { - let def_ast_list = &definition_ast_list; - let ret_str = unifier.internal_stringify( - inst_ret, - &mut |id| { - if let TopLevelDef::Class { name, .. } = - &*def_ast_list[id].0.read() - { - name.to_string() - } else { - unreachable!("must be class id here") - } - }, - &mut |id| format!("typevar{id}"), - &mut None, - ); - return Err(HashSet::from([ - format!( - "expected return type of `{}` in function `{}` (at {})", - ret_str, - name, - ast.as_ref().unwrap().location - ), - ])) - } - - instance_to_stmt.insert( - get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), - FunInstance { - body: Arc::new(fun_body), - unifier_id: 0, - calls: Arc::new(calls), - subst, - }, - ); - } - } else { + let signature_ty_enum = unifier.get_ty(*signature); + let TypeEnum::TFunc(FunSignature { args, ret, vars }) = + signature_ty_enum.as_ref() else { unreachable!("must be typeenum::tfunc") + }; + + let mut vars = vars.clone(); + // None if is not class method + let uninst_self_type = { + if let Some(class_id) = method_class.get(&DefinitionId(id)) { + let class_def = definition_ast_list.get(class_id.0).unwrap(); + let class_def = class_def.0.read(); + let TopLevelDef::Class { type_vars, .. } = &*class_def else { + unreachable!("must be class def") + }; + + let ty_ann = make_self_type_annotation(type_vars, *class_id); + let self_ty = get_type_from_type_annotation_kinds( + &def_list, + unifier, + &ty_ann, + &mut None + )?; + vars.extend(type_vars.iter().map(|ty| { + let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { + unreachable!() + }; + + (*id, *ty) + })); + Some((self_ty, type_vars.clone())) + } else { + None + } + }; + // carefully handle those with bounds, without bounds and no typevars + // if class methods, `vars` also contains all class typevars here + let (type_var_subst_comb, no_range_vars) = { + let mut no_ranges: Vec = Vec::new(); + let var_combs = vars + .values() + .map(|ty| { + unifier.get_instantiations(*ty).unwrap_or_else(|| { + let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else { + unreachable!() + }; + + let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; + no_ranges.push(rigid); + vec![rigid] + }) + }) + .multi_cartesian_product() + .collect_vec(); + let mut result: Vec> = Vec::default(); + for comb in var_combs { + result.push(vars.keys().copied().zip(comb).collect()); + } + // NOTE: if is empty, means no type var, append a empty subst, ok to do this? + if result.is_empty() { + result.push(HashMap::new()); + } + (result, no_ranges) + }; + + for subst in type_var_subst_comb { + // for each instance + let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); + let inst_args = { + args.iter() + .map(|a| FuncArg { + name: a.name, + ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), + default_value: a.default_value.clone(), + }) + .collect_vec() + }; + let self_type = { + uninst_self_type.clone().map(|(self_type, type_vars)| { + let subst_for_self = { + let class_ty_var_ids = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) + { + *id + } else { + unreachable!("must be type var here"); + } + }) + .collect::>(); + subst + .iter() + .filter_map(|(ty_var_id, ty_var_target)| { + if class_ty_var_ids.contains(ty_var_id) { + Some((*ty_var_id, *ty_var_target)) + } else { + None + } + }) + .collect::>() + }; + unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) + }) + }; + let mut identifiers = { + let mut result: HashSet<_> = HashSet::new(); + if self_type.is_some() { + result.insert("self".into()); + } + result.extend(inst_args.iter().map(|x| x.name)); + result + }; + let mut calls: HashMap = HashMap::new(); + let mut inferencer = Inferencer { + top_level: ctx.as_ref(), + defined_identifiers: identifiers.clone(), + function_data: &mut FunctionData { + resolver: resolver.as_ref().unwrap().clone(), + return_type: if unifier.unioned(inst_ret, primitives_ty.none) { + None + } else { + Some(inst_ret) + }, + // NOTE: allowed type vars + bound_variables: no_range_vars.clone(), + }, + unifier, + variable_mapping: { + let mut result: HashMap = HashMap::new(); + if let Some(self_ty) = self_type { + result.insert("self".into(), self_ty); + } + result.extend(inst_args.iter().map(|x| (x.name, x.ty))); + result + }, + primitives: primitives_ty, + virtual_checks: &mut Vec::new(), + calls: &mut calls, + in_handler: false, + }; + + let ast::StmtKind::FunctionDef { body, decorator_list, .. } = + ast.clone().unwrap().node else { + unreachable!("must be function def ast") + }; + if !decorator_list.is_empty() + && matches!(&decorator_list[0].node, + ast::ExprKind::Name{ id, .. } if id == &"extern".into()) + { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + if !decorator_list.is_empty() + && matches!(&decorator_list[0].node, + ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) + { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + + let fun_body = body + .into_iter() + .map(|b| inferencer.fold_stmt(b)) + .collect::, _>>()?; + + let returned = + inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; + { + // check virtuals + let defs = ctx.definitions.read(); + for (subtype, base, loc) in &*inferencer.virtual_checks { + let base_id = { + let base = inferencer.unifier.get_ty(*base); + if let TypeEnum::TObj { obj_id, .. } = &*base { + *obj_id + } else { + return Err(HashSet::from([ + format!("Base type should be a class (at {loc})"), + ])) + } + }; + let subtype_id = { + let ty = inferencer.unifier.get_ty(*subtype); + if let TypeEnum::TObj { obj_id, .. } = &*ty { + *obj_id + } else { + let base_repr = inferencer.unifier.stringify(*base); + let subtype_repr = inferencer.unifier.stringify(*subtype); + return Err(HashSet::from([format!( + "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), + ])) + } + }; + let subtype_entry = defs[subtype_id.0].read(); + let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else { + unreachable!() + }; + + let m = ancestors.iter() + .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); + if m.is_none() { + let base_repr = inferencer.unifier.stringify(*base); + let subtype_repr = inferencer.unifier.stringify(*subtype); + return Err(HashSet::from([format!( + "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), + ])) + } + } + } + if !unifier.unioned(inst_ret, primitives_ty.none) && !returned { + let def_ast_list = &definition_ast_list; + let ret_str = unifier.internal_stringify( + inst_ret, + &mut |id| { + let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() + else { unreachable!("must be class id here") }; + + name.to_string() + }, + &mut |id| format!("typevar{id}"), + &mut None, + ); + return Err(HashSet::from([format!( + "expected return type of `{}` in function `{}` (at {})", + ret_str, + name, + ast.as_ref().unwrap().location + ), + ])) + } + + instance_to_stmt.insert( + get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), + FunInstance { + body: Arc::new(fun_body), + unifier_id: 0, + calls: Arc::new(calls), + subst, + }, + ); } } + Ok(()) }; for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 1367543..e11de99 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -233,11 +233,11 @@ impl TopLevelComposer { }; // check cycle let no_cycle = result.iter().all(|x| { - if let TypeAnnotation::CustomClass { id, .. } = x { - id.0 != p_id.0 - } else { + let TypeAnnotation::CustomClass { id, .. } = x else { unreachable!("must be class kind annotation") - } + }; + + id.0 != p_id.0 }); if no_cycle { result.push(p); @@ -260,14 +260,14 @@ impl TopLevelComposer { }; let child_def = temp_def_list.get(child_id.0).unwrap(); let child_def = child_def.read(); - if let TopLevelDef::Class { ancestors, .. } = &*child_def { - if ancestors.is_empty() { - None - } else { - Some(ancestors[0].clone()) - } - } else { + let TopLevelDef::Class { ancestors, .. } = &*child_def else { unreachable!("child must be top level class def") + }; + + if ancestors.is_empty() { + None + } else { + Some(ancestors[0].clone()) } } @@ -292,39 +292,38 @@ impl TopLevelComposer { let this = this.as_ref(); let other = unifier.get_ty(other); let other = other.as_ref(); - if let ( + let ( TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), - ) = (this, other) - { - // check args - let args_ok = this_args - .iter() - .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) - .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { - (name, type_var_to_concrete_def.get(ty).unwrap()) - })) - .all(|(this, other)| { - if this.0 == &"self".into() && this.0 == other.0 { - true - } else { - this.0 == other.0 - && check_overload_type_annotation_compatible(this.1, other.1, unifier) - } - }); - - // check rets - let ret_ok = check_overload_type_annotation_compatible( - type_var_to_concrete_def.get(this_ret).unwrap(), - type_var_to_concrete_def.get(other_ret).unwrap(), - unifier, - ); - - // return - args_ok && ret_ok - } else { + ) = (this, other) else { unreachable!("this function must be called with function type") - } + }; + + // check args + let args_ok = this_args + .iter() + .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) + .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { + (name, type_var_to_concrete_def.get(ty).unwrap()) + })) + .all(|(this, other)| { + if this.0 == &"self".into() && this.0 == other.0 { + true + } else { + this.0 == other.0 + && check_overload_type_annotation_compatible(this.1, other.1, unifier) + } + }); + + // check rets + let ret_ok = check_overload_type_annotation_compatible( + type_var_to_concrete_def.get(this_ret).unwrap(), + type_var_to_concrete_def.get(other_ret).unwrap(), + unifier, + ); + + // return + args_ok && ret_ok } pub fn check_overload_field_type( diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 82b5bf9..a44de76 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -163,11 +163,11 @@ pub fn parse_ast_to_type_annotation_kinds( let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { + let TopLevelDef::Class { type_vars, .. } = &*def_read else { unreachable!("must be class here") - } + }; + + type_vars.clone() } else { locked.get(&obj_id).unwrap().clone() } @@ -497,13 +497,11 @@ pub fn get_type_from_type_annotation_kinds( TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Constant { ty, value, .. } => { let ty_enum = unifier.get_ty(*ty); - let (ty, loc) = match &*ty_enum { - TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => { - (ntv_underlying_ty[0], loc) - } - _ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()), + let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else { + unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()); }; + let ty = ntv_underlying_ty[0]; let var = unifier.get_fresh_constant(value.clone(), ty, *loc); Ok(var) } @@ -596,15 +594,14 @@ pub fn check_overload_type_annotation_compatible( let a = &*a; let b = unifier.get_ty(*b); let b = &*b; - if let ( + let ( TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. }, - ) = (a, b) - { - a == b - } else { + ) = (a, b) else { unreachable!("must be type var") - } + }; + + a == b } (TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) | (TypeAnnotation::List(a), TypeAnnotation::List(b)) => { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7cc5dd9..0b9b5a0 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -241,35 +241,35 @@ impl<'a> Fold<()> for Inferencer<'a> { let targets: Result, _> = targets .into_iter() .map(|target| { - if let ExprKind::Name { id, ctx } = target.node { - self.defined_identifiers.insert(id); - let target_ty = if let Some(ty) = self.variable_mapping.get(&id) - { - *ty - } else { - let unifier: &mut Unifier = self.unifier; - self.function_data - .resolver - .get_symbol_type( - unifier, - &self.top_level.definitions.read(), - self.primitives, - id, - ) - .unwrap_or_else(|_| { - self.variable_mapping.insert(id, value_ty); - value_ty - }) - }; - let location = target.location; - self.unifier.unify(value_ty, target_ty).map(|()| Located { - location, - node: ExprKind::Name { id, ctx }, - custom: Some(target_ty), - }) - } else { + let ExprKind::Name { id, ctx } = target.node else { unreachable!() - } + }; + + self.defined_identifiers.insert(id); + let target_ty = if let Some(ty) = self.variable_mapping.get(&id) + { + *ty + } else { + let unifier: &mut Unifier = self.unifier; + self.function_data + .resolver + .get_symbol_type( + unifier, + &self.top_level.definitions.read(), + self.primitives, + id, + ) + .unwrap_or_else(|_| { + self.variable_mapping.insert(id, value_ty); + value_ty + }) + }; + let location = target.location; + self.unifier.unify(value_ty, target_ty).map(|()| Located { + location, + node: ExprKind::Name { id, ctx }, + custom: Some(target_ty), + }) }) .collect(); let loc = node.location; @@ -465,12 +465,12 @@ impl<'a> Fold<()> for Inferencer<'a> { let var_map = params .iter() .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) { - assert_eq!(*id, *id_var); - (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) - } else { + let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) else { unreachable!() - } + }; + + assert_eq!(*id, *id_var); + (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) }) .collect::>(); Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 21d4fd2..85dfedc 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -499,12 +499,9 @@ impl Unifier { let instantiated = self.instantiate_fun(b, signature); let r = self.get_ty(instantiated); let r = r.as_ref(); - let signature; - if let TypeEnum::TFunc(s) = r { - signature = s; - } else { - unreachable!(); - } + let TypeEnum::TFunc(signature) = r else { + unreachable!() + }; // we check to make sure that all required arguments (those without default // arguments) are provided, and do not provide the same argument twice. let mut required = required.to_vec(); @@ -940,13 +937,12 @@ impl Unifier { top_level.as_ref().map_or_else( || format!("{id}"), |top_level| { - if let TopLevelDef::Class { name, .. } = - &*top_level.definitions.read()[id].read() - { - name.to_string() - } else { + let top_level_def = &top_level.definitions.read()[id]; + let TopLevelDef::Class { name, .. } = &*top_level_def.read() else { unreachable!("expected class definition") - } + }; + + name.to_string() }, ) }, diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 0762a10..3069b57 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -339,23 +339,21 @@ fn test_recursive_subst() { let int = *env.type_mapping.get("int").unwrap(); let foo_id = *env.type_mapping.get("Foo").unwrap(); let foo_ty = env.unifier.get_ty(foo_id); - let mapping: HashMap<_, _>; with_fields(&mut env.unifier, foo_id, |_unifier, fields| { fields.insert("rec".into(), (foo_id, true)); }); - if let TypeEnum::TObj { params, .. } = &*foo_ty { - mapping = params.iter().map(|(id, _)| (*id, int)).collect(); - } else { + let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() - } + }; + let mapping = params.iter().map(|(id, _)| (*id, int)).collect(); let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated_ty = env.unifier.get_ty(instantiated); - if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { - assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); - assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); - } else { + + let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() - } + }; + assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); + assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); } #[test] diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 3ea7238..f34f55a 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -363,12 +363,11 @@ fn main() { .unwrap_or_else(|_| panic!("cannot find run() entry point")) .0] .write(); - if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance { - instance_to_symbol.insert(String::new(), "run".to_string()); - instance_to_stmt[""].clone() - } else { + let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else { unreachable!() - } + }; + instance_to_symbol.insert(String::new(), "run".to_string()); + instance_to_stmt[""].clone() }; let llvm_options = CodeGenLLVMOptions {