From eac164ce115129ec2e9e2acfcc604397c98e0db1 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 18 Jul 2024 16:07:47 +0800 Subject: [PATCH] WIP --- nac3core/irrt/irrt/error_context.hpp | 87 ++++++++++ nac3core/irrt/irrt_everything.hpp | 3 +- nac3core/src/codegen/expr.rs | 44 ++++- nac3core/src/codegen/irrt/error_context.rs | 185 +++++++++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 3 + nac3core/src/codegen/irrt/string.rs | 34 ++++ nac3core/src/codegen/irrt/util.rs | 79 +++++++++ nac3core/src/codegen/mod.rs | 17 +- nac3core/src/codegen/numpy.rs | 4 +- 9 files changed, 438 insertions(+), 18 deletions(-) create mode 100644 nac3core/irrt/irrt/error_context.hpp create mode 100644 nac3core/src/codegen/irrt/error_context.rs create mode 100644 nac3core/src/codegen/irrt/string.rs create mode 100644 nac3core/src/codegen/irrt/util.rs diff --git a/nac3core/irrt/irrt/error_context.hpp b/nac3core/irrt/irrt/error_context.hpp new file mode 100644 index 00000000..57819655 --- /dev/null +++ b/nac3core/irrt/irrt/error_context.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include + +namespace { +// nac3core's "str" struct type definition +template +struct Str { + const char* content; + SizeT length; +}; + +// A limited set of errors IRRT could use. +typedef uint32_t ErrorId; +struct ErrorIds { + ErrorId index_error; + ErrorId value_error; + ErrorId assertion_error; + ErrorId runtime_error; + ErrorId type_error; +}; + +struct ErrorContext { + // Context + ErrorIds* error_ids; + + // Error thrown by IRRT + ErrorId error_id; + const char* message_template; // MUST BE `&'static` + uint64_t param1; + uint64_t param2; + uint64_t param3; + + void initialize(ErrorIds* error_ids) { + this->error_ids = error_ids; + clear_error(); + } + + void clear_error() { + // Point the message_template to an empty str. 
Don't set it to nullptr as a sentinel + this->message_template = ""; + } + + void + set_error(ErrorId error_id, const char* message, uint64_t param1 = 0, uint64_t param2 = 0, uint64_t param3 = 0) { + this->error_id = error_id; + this->message_template = message; + this->param1 = param1; + this->param2 = param2; + this->param3 = param3; + } + + bool has_error() { return !cstr_utils::is_empty(message_template); } + + /// Get a nac3core-understanding `Str` that containing + /// the error message template + template + void get_error_str(Str* dst_str) { + dst_str->content = message_template; + dst_str->length = (SizeT)cstr_utils::length(message_template); + } +}; +} // namespace + +extern "C" { +void __nac3_error_context_initialize(ErrorContext* errctx, ErrorIds* error_ids) { + errctx->initialize(error_ids); +} + +bool __nac3_error_context_has_no_error(ErrorContext* errctx) { + return !errctx->has_error(); +} + +void __nac3_error_context_get_error_str(ErrorContext* errctx, Str* dst_str) { + errctx->get_error_str(dst_str); +} + +void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str* dst_str) { + errctx->get_error_str(dst_str); +} + +// Used for testing +void __nac3_error_dummy_raise(ErrorContext* errctx) { + errctx->set_error(errctx->error_ids->runtime_error, "Error thrown from __nac3_error_dummy_raise"); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index a1c45e1e..d0480a5b 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -2,4 +2,5 @@ #include #include -#include \ No newline at end of file +#include +#include \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c42c8444..d3571a52 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -576,6 +576,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { params: [Option>; 3], loc: Location, ) { + let error_id = 
self.resolver.get_string_id(name); + let error_id = self.ctx.i32_type().const_int(error_id as u64, false); + self.raise_exn_by_id(generator, error_id, msg, params, loc); + } + + pub fn raise_exn_by_id( + &mut self, + generator: &mut G, + error_id: IntValue<'ctx>, + msg: BasicValueEnum<'ctx>, + params: [Option>; 3], + loc: Location, + ) { + assert_eq!(error_id.get_type().get_bit_width(), 32); + let zelf = if let Some(exception_val) = self.exception_val { exception_val } else { @@ -587,9 +602,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let int32 = self.ctx.i32_type(); let zero = int32.const_zero(); unsafe { let id_ptr = self.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); - let id = self.resolver.get_string_id(name); - self.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); + // The exception ID is now supplied by the caller; write it into the `exn.id` field. + self.builder.build_store(id_ptr, error_id).unwrap(); let ptr = self .builder .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") @@ -652,6 +667,32 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.raise_exn(generator, err_name, err_msg, params, loc); self.builder.position_at_end(then_block); } + + pub fn make_assert_impl_by_id( + &mut self, + generator: &mut G, + cond: IntValue<'ctx>, + err_id: IntValue<'ctx>, + err_msg: BasicValueEnum<'ctx>, + params: [Option>; 3], + loc: Location, + ) { + let i1 = self.ctx.bool_type(); + let i1_true = i1.const_all_ones(); + // we assume that the condition is most probably true, so the normal path is the most + // probable path + // even if this assumption is violated, it does not matter as exception unwinding is + // slow anyway... 
+ let cond = call_expect(self, cond, i1_true, Some("expect")); + let current_bb = self.builder.get_insert_block().unwrap(); + let current_fun = current_bb.get_parent().unwrap(); + let then_block = self.ctx.insert_basic_block_after(current_bb, "succ"); + let exn_block = self.ctx.append_basic_block(current_fun, "fail"); + self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap(); + self.builder.position_at_end(exn_block); + self.raise_exn_by_id(generator, err_id, err_msg, params, loc); + self.builder.position_at_end(then_block); + } } /// See [`CodeGenerator::gen_constructor`]. diff --git a/nac3core/src/codegen/irrt/error_context.rs b/nac3core/src/codegen/irrt/error_context.rs new file mode 100644 index 00000000..a1a4bf48 --- /dev/null +++ b/nac3core/src/codegen/irrt/error_context.rs @@ -0,0 +1,185 @@ +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +use super::{string::Str, util::get_sized_dependent_function_name}; + +/// The [`IntModel`] of nac3core's error ID. +/// +/// It is always [`Int32`]. 
+type ErrorId = Int32; + +pub struct ErrorIdsFields { + pub index_error: Field>, + pub value_error: Field>, + pub assertion_error: Field>, + pub runtime_error: Field>, + pub type_error: Field>, +} + +/// Corresponds to IRRT's `struct ErrorIds` +#[derive(Debug, Clone, Copy, Default)] +pub struct ErrorIds; + +impl<'ctx> IsStruct<'ctx> for ErrorIds { + type Fields = ErrorIdsFields; + + fn struct_name(&self) -> &'static str { + "ErrorIds" + } + + fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields { + Self::Fields { + index_error: builder.add_field_auto("index_error"), + value_error: builder.add_field_auto("value_error"), + assertion_error: builder.add_field_auto("assertion_error"), + runtime_error: builder.add_field_auto("runtime_error"), + type_error: builder.add_field_auto("type_error"), + } + } +} + +pub struct ErrorContextFields { + pub error_ids: Field>>, + pub error_id: Field>, + pub message_template: Field>>, + pub param1: Field>, + pub param2: Field>, + pub param3: Field>, +} + +/// Corresponds to IRRT's `struct ErrorContext` +#[derive(Debug, Clone, Copy, Default)] +pub struct ErrorContext; + +impl<'ctx> IsStruct<'ctx> for ErrorContext { + type Fields = ErrorContextFields; + + fn struct_name(&self) -> &'static str { + "ErrorContext" // Matches IRRT's `struct ErrorContext`. + } + + fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields { + Self::Fields { + error_ids: builder.add_field_auto("error_ids"), + error_id: builder.add_field_auto("error_id"), + message_template: builder.add_field_auto("message_template"), + param1: builder.add_field_auto("param1"), + param2: builder.add_field_auto("param2"), + param3: builder.add_field_auto("param3"), + } + } +} + +// Prepare ErrorIds +fn build_error_ids<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> Pointer<'ctx, StructModel> { + // ErrorIdsLens.get_fields(ctx.ctx).assertion_error. 
+ let error_ids = StructModel(ErrorIds).alloca(ctx, "error_ids"); + let i32_model = FixedIntModel(Int32); + // i32_model.make_constant() + + let get_string_id = + |string_id| i32_model.constant(ctx.ctx, ctx.resolver.get_string_id(string_id) as u64); + + error_ids.gep(ctx, |f| f.index_error).store(ctx, get_string_id("0:IndexError")); + error_ids.gep(ctx, |f| f.value_error).store(ctx, get_string_id("0:ValueError")); + error_ids.gep(ctx, |f| f.assertion_error).store(ctx, get_string_id("0:AssertionError")); + error_ids.gep(ctx, |f| f.runtime_error).store(ctx, get_string_id("0:RuntimeError")); + error_ids.gep(ctx, |f| f.type_error).store(ctx, get_string_id("0:TypeError")); + + error_ids +} + +pub fn call_nac3_error_context_initialize<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + perrctx: Pointer<'ctx, StructModel>, + perror_ids: Pointer<'ctx, StructModel>, +) { + FunctionBuilder::begin(ctx, "__nac3_error_context_initialize") + .arg("errctx", PointerModel(StructModel(ErrorContext)), perrctx) + .arg("error_ids", PointerModel(StructModel(ErrorIds)), perror_ids) + .returning_void(); +} + +pub fn call_nac3_error_context_has_no_error<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + errctx: Pointer<'ctx, StructModel>, +) -> FixedInt<'ctx, Bool> { + FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error") + .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx) + .returning("has_error", FixedIntModel(Bool)) +} + +pub fn call_nac3_error_context_get_error_str<'ctx>( + sizet: IntModel<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + errctx: Pointer<'ctx, StructModel>, + dst_str: Pointer<'ctx, StructModel>>, +) { + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name(sizet, "__nac3_error_context_get_error_str"), + ) + .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx) + .arg("dst_str", PointerModel(StructModel(Str { sizet })), dst_str) + .returning_void(); +} + +/// Setup a [`ErrorContext`] that could +/// be passed to IRRT functions taking in a 
`ErrorContext* errctx` +/// for error reporting purposes. +/// +/// Also see: [`check_error_context`] +pub fn setup_error_context<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, +) -> Pointer<'ctx, StructModel> { + let error_ids = build_error_ids(ctx); + let errctx_ptr = StructModel(ErrorContext).alloca(ctx, "errctx"); + call_nac3_error_context_initialize(ctx, errctx_ptr, error_ids); + errctx_ptr +} + +/// Check an [`ErrorContext`] to see +/// if it contains an error. +/// +/// If there is an error, an LLVM exception will be raised at runtime. +pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + errctx_ptr: Pointer<'ctx, StructModel>, +) { + let sizet = IntModel(generator.get_size_type(ctx.ctx)); + + // Is the ErrorContext free of errors? This is the condition asserted below. + let has_no_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr); + + // Get the error message (doesn't matter even if there's actually no error) + let pstr = StructModel(Str { sizet }).alloca(ctx, "error_str"); + call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, pstr); + + // Load all the values for `ctx.make_assert_impl_by_id` + let error_id = errctx_ptr.gep(ctx, |f| f.error_id).load(ctx, "error_id"); + let error_str = pstr.load(ctx, "error_str"); + let param1 = errctx_ptr.gep(ctx, |f| f.param1).load(ctx, "param1"); + let param2 = errctx_ptr.gep(ctx, |f| f.param2).load(ctx, "param2"); + let param3 = errctx_ptr.gep(ctx, |f| f.param3).load(ctx, "param3"); + + // Make assert + ctx.make_assert_impl_by_id( + generator, + has_no_error.value, + error_id.value, + error_str.get_llvm_value(), + [Some(param1.value), Some(param2.value), Some(param3.value)], + ctx.current_loc, + ); +} + +pub fn call_nac3_dummy_raise( + generator: &mut G, + ctx: &mut CodeGenContext, +) { + let errctx = setup_error_context(ctx); + FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise") + .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx) + .returning_void(); + 
check_error_context(generator, ctx, errctx); +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index dfb91611..90f88f2e 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,9 @@ use crate::typecheck::typedef::Type; +pub mod error_context; +pub mod string; mod test; +pub mod util; use super::{ classes::{ diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 00000000..867287ac --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,34 @@ +use crate::codegen::model::*; + +pub struct StrFields<'ctx> { + /// Pointer to the string. Does not have to be null-terminated. + pub content: Field>>, + /// Number of bytes this string occupies in space. + /// + /// The [`IntModel`] matches [`Str::sizet`]. + pub length: Field>, +} + +/// Corresponds to IRRT's `struct Str` +/// +/// nac3core's LLVM representation of a string in memory +#[derive(Debug, Clone, Copy)] +pub struct Str<'ctx> { + /// The `SizeT` type of this string. 
+ pub sizet: IntModel<'ctx>, +} + +impl<'ctx> IsStruct<'ctx> for Str<'ctx> { + type Fields = StrFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "Str" + } + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + Self::Fields { + content: builder.add_field_auto("content"), + length: builder.add_field("length", self.sizet), + } + } +} diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs new file mode 100644 index 00000000..9b0e3cf0 --- /dev/null +++ b/nac3core/src/codegen/irrt/util.rs @@ -0,0 +1,79 @@ +use inkwell::{ + types::{BasicMetadataTypeEnum, BasicType, IntType}, + values::{AnyValue, BasicMetadataValueEnum}, +}; + +use crate::{ + codegen::{model::*, CodeGenContext}, + util::SizeVariant, +}; + +fn get_size_variant(ty: IntType) -> SizeVariant { + match ty.get_bit_width() { + 32 => SizeVariant::Bits32, + 64 => SizeVariant::Bits64, + _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), + } +} + +#[must_use] +pub fn get_sized_dependent_function_name(ty: IntModel, fn_name: &str) -> String { + let mut fn_name = fn_name.to_owned(); + match get_size_variant(ty.0) { + SizeVariant::Bits32 => { + // Do nothing, `fn_name` already has the correct name + } + SizeVariant::Bits64 => { + // Append "64", this is the naming convention + fn_name.push_str("64"); + } + } + fn_name +} + +// TODO: Variadic argument? 
+pub struct FunctionBuilder<'ctx, 'a> { + ctx: &'a CodeGenContext<'ctx, 'a>, + fn_name: &'a str, + arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>, +} + +impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { + pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self { + FunctionBuilder { ctx, fn_name, arguments: Vec::new() } + } + + // The name is for self-documentation + #[must_use] + pub fn arg>(mut self, _name: &'static str, model: M, value: M::Value) -> Self { + self.arguments + .push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into())); + self + } + + pub fn returning>(self, name: &'static str, return_model: M) -> M::Value { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = return_model.get_llvm_type(self.ctx.ctx); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); + return_model.review(self.ctx.ctx, ret.as_any_value_enum()) + } + + // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None + pub fn returning_void(self) { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = self.ctx.ctx.void_type(); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + self.ctx.builder.build_call(function, ¶m_vals, "").unwrap(); + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 85b963bb..2b67ec4e 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -23,7 +23,9 @@ use inkwell::{ values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, 
AddressSpace, IntPredicate, OptimizationLevel, }; +use irrt::string::Str; use itertools::Itertools; +use model::*; use nac3parser::ast::{Location, Stmt, StrRef}; use parking_lot::{Condvar, Mutex}; use std::collections::{HashMap, HashSet}; @@ -655,19 +657,8 @@ pub fn gen_func_impl< (primitives.float, context.f64_type().into()), (primitives.bool, context.i8_type().into()), (primitives.str, { - let name = "str"; - match module.get_struct_type(name) { - None => { - let str_type = context.opaque_struct_type("str"); - let fields = [ - context.i8_type().ptr_type(AddressSpace::default()).into(), - generator.get_size_type(context).into(), - ]; - str_type.set_body(&fields, false); - str_type.into() - } - Some(t) => t.as_basic_type_enum(), - } + let sizet = IntModel(generator.get_size_type(context)); + StructModel(Str { sizet }).get_llvm_type(context) }), (primitives.range, RangeType::new(context).as_base_type().into()), (primitives.exception, { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7421c894..bce93baa 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -9,7 +9,7 @@ use crate::{ irrt::{ calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, - call_ndarray_calc_size, + call_ndarray_calc_size, error_context::call_nac3_dummy_raise, }, llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, @@ -1742,6 +1742,8 @@ pub fn gen_ndarray_zeros<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + call_nac3_dummy_raise(generator, context); + let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;