forked from M-Labs/nac3
1
0
Fork 0

core: irrt ErrorContext

This commit is contained in:
lyken 2024-07-13 23:58:29 +08:00
parent 9e78139373
commit 8863cd64a9
16 changed files with 463 additions and 49 deletions

View File

@ -163,7 +163,10 @@
clippy
pre-commit
rustfmt
rust-analyzer
];
# https://nixos.wiki/wiki/Rust#Shell.nix_example
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -0,0 +1,69 @@
#pragma once
#include "int_defs.hpp"
#include "utils.hpp"
namespace {
// nac3core's "str" struct type definition
template <typename SizeT>
struct Str {
const char* content;
SizeT length;
};
struct ErrorContext {
const char* message_template; // MUST BE `&'static`
uint64_t param1;
uint64_t param2;
uint64_t param3;
void initialize() {
clear_error();
}
void clear_error() {
// Point the message_template to an empty str. Don't set it to nullptr as a sentinel
set_error("");
}
void set_error(const char* message, uint64_t param1 = 0, uint64_t param2 = 0, uint64_t param3 = 0) {
this->message_template = message;
this->param1 = param1;
this->param2 = param2;
this->param3 = param3;
}
bool has_error() {
return !cstr_utils::is_empty(message_template);
}
template <typename SizeT>
void set_error_str(Str<SizeT> *dst_str) {
dst_str->content = message_template;
dst_str->length = (SizeT) cstr_utils::length(message_template);
}
};
}
extern "C" {
void __nac3_error_context_initialize(ErrorContext* errctx) {
errctx->initialize();
}
uint8_t __nac3_error_context_has_no_error(ErrorContext* errctx) {
return !errctx->has_error();
}
void __nac3_error_context_get_error_str(ErrorContext* errctx, Str<int32_t> *dst_str) {
errctx->set_error_str<int32_t>(dst_str);
}
void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str<int64_t> *dst_str) {
errctx->set_error_str<int64_t>(dst_str);
}
void __nac3_error_dummy_raise(ErrorContext* errctx) {
errctx->set_error("THROWN FROM __nac3_error_dummy_raise!!!!!!");
}
}

View File

@ -0,0 +1,60 @@
#pragma once
#include "int_defs.hpp"
namespace {
namespace string {
bool is_empty(const char* str) {
return str[0] == '\0';
}
int8_t compare(const char* a, const char* b) {
uint32_t i = 0;
while (true) {
if (a[i] < b[i]) {
return -1;
} else if (a[i] > b[i]) {
return 1;
} else { // a[i] == b[i]
if (a[i] == '\0') {
return 0;
} else {
continue;
}
}
}
}
int8_t equal(const char* a, const char* b) {
return compare(a, b) == 0;
}
uint32_t length(const char* str) {
uint32_t length = 0;
while (*str != '\0') {
length++;
str++;
}
return length;
}
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src[i] != '\0') {
dst[i] = '\0';
return false;
}
if (src[i] == '\0') {
dst[i] = '\0';
return true;
}
dst[i] = src[i];
}
__builtin_unreachable();
}
}
}

View File

@ -1,5 +1,7 @@
#pragma once
#include "int_defs.hpp"
namespace {
template <typename T>
const T& max(const T& a, const T& b) {
@ -18,4 +20,69 @@ bool arrays_match(int len, T* as, T* bs) {
}
return true;
}
template<typename T>
uint32_t int_log_floor(T value, T base) {
uint32_t result = 0;
while (value >= base) {
result++;
value /= base;
}
return result;
}
namespace cstr_utils {
bool is_empty(const char* str) {
return str[0] == '\0';
}
int8_t compare(const char* a, const char* b) {
uint32_t i = 0;
while (true) {
if (a[i] < b[i]) {
return -1;
} else if (a[i] > b[i]) {
return 1;
} else { // a[i] == b[i]
if (a[i] == '\0') {
return 0;
} else {
i++;
}
}
}
}
int8_t equal(const char* a, const char* b) {
return compare(a, b) == 0;
}
uint32_t length(const char* str) {
uint32_t length = 0;
while (*str != '\0') {
length++;
str++;
}
return length;
}
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src[i] != '\0') {
dst[i] = '\0';
return false;
}
if (src[i] == '\0') {
dst[i] = '\0';
return true;
}
dst[i] = src[i];
}
__builtin_unreachable();
}
}
}

View File

@ -1,3 +1,6 @@
#pragma once
#include "irrt/core.hpp"
#include "irrt/error_context.hpp"
#include "irrt/int_defs.hpp"
#include "irrt/utils.hpp"

View File

@ -9,8 +9,11 @@
#include "test/core.hpp"
#include "test/test_core.hpp"
#include "test/test_utils.hpp"
int main() {
test_int_exp();
run_test_core();
run_test_print();
run_test_utils();
return 0;
}

View File

@ -4,39 +4,39 @@
#include <cstdio>
template <class T>
void print_value(const T& value) {}
void print_value(T value);
template <>
void print_value(const int8_t& value) {
void print_value(char value) {
printf("'%c' (ord=%d)", value, value);
}
template <>
void print_value(int8_t value) {
printf("%d", value);
}
template <>
void print_value(const int32_t& value) {
void print_value(int32_t value) {
printf("%d", value);
}
template <>
void print_value(const uint8_t& value) {
void print_value(uint8_t value) {
printf("%u", value);
}
template <>
void print_value(const uint32_t& value) {
void print_value(uint32_t value) {
printf("%u", value);
}
template <>
void print_value(const double& value) {
void print_value(double value) {
printf("%f", value);
}
// template <double>
// void print_value(const double& value) {
// printf("%f", value);
// }
//
// template <char *>
// void print_value(const char*& value) {
// printf("%f", value);
// }
template <>
void print_value(char* value) {
printf("%s", value);
}

View File

@ -8,4 +8,8 @@ void test_int_exp() {
assert_values_match(125, __nac3_int_exp_impl<int32_t>(5, 3));
assert_values_match(3125, __nac3_int_exp_impl<int32_t>(5, 5));
}
void run_test_core() {
test_int_exp();
}

View File

@ -0,0 +1,27 @@
#pragma once
#include "core.hpp"
#include "../irrt/utils.hpp"
void test_int_log_10() {
BEGIN_TEST();
assert_values_match((uint32_t) 0, int_log_floor(0, 10));
assert_values_match((uint32_t) 0, int_log_floor(9, 10));
assert_values_match((uint32_t) 1, int_log_floor(10, 10));
assert_values_match((uint32_t) 1, int_log_floor(11, 10));
assert_values_match((uint32_t) 1, int_log_floor(99, 10));
assert_values_match((uint32_t) 2, int_log_floor(100, 10));
assert_values_match((uint32_t) 2, int_log_floor(101, 10));
}
void test_cstr_utils() {
BEGIN_TEST();
assert_values_match((uint32_t) 42, (uint32_t) cstr_utils::length("THROWN FROM __nac3_error_dummy_raise!!!!!!"));
}
void run_test_utils() {
test_int_log_10();
test_cstr_utils();
}

View File

@ -1,10 +1,32 @@
use inkwell::types::{BasicTypeEnum, IntType};
use crate::codegen::optics::{AddressLens, GepGetter, IntLens, StructureOptic};
use crate::codegen::optics::{AddressLens, FieldBuilder, GepGetter, IntLens, StructureOptic};
// use crate::codegen::structure::{
// FieldLensBuilder, IntLens, LensWithFieldInfo, PointerLens, StructFieldLens,
// };
#[derive(Debug, Clone)]
pub struct StrLens<'ctx> {
pub size_type: IntType<'ctx>,
}
// TODO: nac3core has hardcoded a lot of "str"
pub struct StrFields<'ctx> {
pub content: GepGetter<AddressLens<IntLens<'ctx>>>,
pub length: GepGetter<IntLens<'ctx>>,
}
impl<'ctx> StructureOptic<'ctx> for StrLens<'ctx> {
type Fields = StrFields<'ctx>;
fn struct_name(&self) -> &'static str {
"str"
}
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
StrFields {
content: builder.add_field("content", AddressLens(IntLens(builder.ctx.i8_type()))),
length: builder.add_field("length", IntLens(self.size_type)),
}
}
}
pub struct NpArrayFields<'ctx> {
pub data: GepGetter<AddressLens<IntLens<'ctx>>>,
@ -54,7 +76,7 @@ impl<'ctx> StructureOptic<'ctx> for IrrtStringLens {
type Fields = IrrtStringFields<'ctx>;
fn struct_name(&self) -> &'static str {
todo!()
"String"
}
fn build_fields(
@ -71,15 +93,18 @@ impl<'ctx> StructureOptic<'ctx> for IrrtStringLens {
}
}
pub struct ErrorContextFields {
pub message: GepGetter<IrrtStringLens>,
pub struct ErrorContextFields<'ctx> {
pub message_template: GepGetter<AddressLens<IntLens<'ctx>>>,
pub param1: GepGetter<IntLens<'ctx>>,
pub param2: GepGetter<IntLens<'ctx>>,
pub param3: GepGetter<IntLens<'ctx>>,
}
#[derive(Debug, Clone, Copy)]
pub struct ErrorContextLens;
impl<'ctx> StructureOptic<'ctx> for ErrorContextLens {
type Fields = ErrorContextFields;
type Fields = ErrorContextFields<'ctx>;
fn struct_name(&self) -> &'static str {
"ErrorContext"
@ -89,6 +114,12 @@ impl<'ctx> StructureOptic<'ctx> for ErrorContextLens {
&self,
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
) -> Self::Fields {
ErrorContextFields { message: builder.add_field("message", IrrtStringLens) }
ErrorContextFields {
message_template: builder
.add_field("message_template", AddressLens(IntLens(builder.ctx.i8_type()))),
param1: builder.add_field("param1", IntLens(builder.ctx.i64_type())),
param2: builder.add_field("param2", IntLens(builder.ctx.i64_type())),
param3: builder.add_field("param3", IntLens(builder.ctx.i64_type())),
}
}
}

View File

@ -23,6 +23,9 @@ use inkwell::{
use itertools::Either;
use nac3parser::ast::Expr;
pub mod classes;
pub mod new;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(

View File

@ -0,0 +1,154 @@
use inkwell::{
types::{BasicMetadataTypeEnum, BasicType, IntType},
values::{AnyValue, BasicMetadataValueEnum, IntValue},
};
use crate::{
codegen::{
optics::{Address, AddressLens, IntLens, Optic, OpticValue, Prism},
CodeGenContext, CodeGenerator,
},
util::SizeVariant,
};
use super::classes::{ErrorContextLens, StrLens};
fn get_size_variant(ty: IntType) -> SizeVariant {
match ty.get_bit_width() {
32 => SizeVariant::Bits32,
64 => SizeVariant::Bits64,
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
}
}
fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String {
let mut fn_name = fn_name.to_owned();
match get_size_variant(ty) {
SizeVariant::Bits32 => {
// Do nothing, `fn_name` already has the correct name
}
SizeVariant::Bits64 => {
// Append "64", this is the naming convention
fn_name.push_str("64");
}
}
fn_name
}
// TODO: Variadic argument?
pub struct FunctionBuilder<'ctx, 'a> {
ctx: &'a CodeGenContext<'ctx, 'a>,
fn_name: &'a str,
arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>,
}
impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self {
FunctionBuilder { ctx, fn_name, arguments: Vec::new() }
}
// The name is for self-documentation
#[must_use]
pub fn arg<S: Optic<'ctx>>(mut self, _name: &'static str, optic: &S, arg: &S::Value) -> Self {
self.arguments
.push((optic.get_llvm_type(self.ctx.ctx).into(), arg.get_llvm_value().into()));
self
}
pub fn returning<S: Prism<'ctx>>(self, name: &'static str, return_prism: &S) -> S::Value {
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
let return_type = return_prism.get_llvm_type(self.ctx.ctx);
let fn_type = return_type.fn_type(&param_tys, false);
self.ctx.module.add_function(self.fn_name, fn_type, None)
});
let ret = self.ctx.builder.build_call(function, &param_vals, name).unwrap();
return_prism.review(ret.as_any_value_enum())
}
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
pub fn returning_void(self) {
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
let return_type = self.ctx.ctx.void_type();
let fn_type = return_type.fn_type(&param_tys, false);
self.ctx.module.add_function(self.fn_name, fn_type, None)
});
self.ctx.builder.build_call(function, &param_vals, "").unwrap();
}
}
pub fn call_nac3_error_context_initialize<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
errctx: &Address<'ctx, ErrorContextLens>,
) {
FunctionBuilder::begin(ctx, "__nac3_error_context_initialize")
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
.returning_void();
}
pub fn call_nac3_error_context_has_no_error<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
errctx: &Address<'ctx, ErrorContextLens>,
) -> IntValue<'ctx> {
FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error")
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
.returning("has_error", &IntLens(ctx.ctx.bool_type()))
}
pub fn call_nac3_error_context_get_error_str<'ctx>(
size_type: IntType<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
errctx: &Address<'ctx, ErrorContextLens>,
dst_str: &Address<'ctx, StrLens<'ctx>>,
) -> IntValue<'ctx> {
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(size_type, "__nac3_error_context_get_error_str"),
)
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
.arg("dst_str", &AddressLens(StrLens { size_type }), dst_str)
.returning("has_error", &IntLens(ctx.ctx.bool_type()))
}
pub fn call_nac3_dummy_raise<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
errctx: &Address<'ctx, ErrorContextLens>,
) {
FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise")
.arg("errctx", &AddressLens(ErrorContextLens), errctx)
.returning_void();
}
pub fn test_dummy_raise<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'_, '_>,
) {
let size_type = generator.get_size_type(ctx.ctx);
let errctx_ptr = ErrorContextLens.alloca(ctx, "errctx");
call_nac3_error_context_initialize(ctx, &errctx_ptr);
call_nac3_dummy_raise(ctx, &errctx_ptr);
let has_error = call_nac3_error_context_has_no_error(ctx, &errctx_ptr);
let error_str_ptr = StrLens { size_type }.alloca(ctx, "error_str");
call_nac3_error_context_get_error_str(size_type, ctx, &errctx_ptr, &error_str_ptr);
let error_str = error_str_ptr.load(ctx, "error_str");
let param1 = errctx_ptr.view(ctx, |fields| &fields.param1).load(ctx, "param1");
let param2 = errctx_ptr.view(ctx, |fields| &fields.param2).load(ctx, "param2");
let param3 = errctx_ptr.view(ctx, |fields| &fields.param3).load(ctx, "param3");
ctx.make_assert_impl(
generator,
has_error,
"0:RuntimeError", // TODO: Make this dynamic (within IRRT), but this is probably not trivial
error_str.get_llvm_value(),
[Some(param1), Some(param2), Some(param3)],
ctx.current_loc,
);
}

View File

@ -11,7 +11,6 @@ mod tests {
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");

View File

@ -23,8 +23,10 @@ use inkwell::{
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use irrt::classes::StrLens;
use itertools::Itertools;
use nac3parser::ast::{Location, Stmt, StrRef};
use optics::Optic;
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
@ -647,6 +649,8 @@ pub fn gen_func_impl<
..primitives
};
let llvm_str_ty =
StrLens { size_type: generator.get_size_type(context) }.get_llvm_type(context);
let mut type_cache: HashMap<_, _> = [
(primitives.int32, context.i32_type().into()),
(primitives.int64, context.i64_type().into()),
@ -654,21 +658,7 @@ pub fn gen_func_impl<
(primitives.uint64, context.i64_type().into()),
(primitives.float, context.f64_type().into()),
(primitives.bool, context.i8_type().into()),
(primitives.str, {
let name = "str";
match module.get_struct_type(name) {
None => {
let str_type = context.opaque_struct_type("str");
let fields = [
context.i8_type().ptr_type(AddressSpace::default()).into(),
generator.get_size_type(context).into(),
];
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum(),
}
}),
(primitives.str, llvm_str_ty),
(primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, {
let name = "Exception";
@ -678,7 +668,7 @@ pub fn gen_func_impl<
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let str_ty = llvm_str_ty;
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()

View File

@ -173,7 +173,7 @@ impl<'ctx, AddresseeOptic: MemoryGetter<'ctx>> Address<'ctx, AddresseeOptic> {
// To make [`Address`] convenient to use
impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> {
pub fn set(&self, ctx: &CodeGenContext<'ctx, '_>, value: &AddresseeOptic::Value) {
self.addressee_optic.set(ctx, self.address, value)
self.addressee_optic.set(ctx, self.address, value);
}
}
@ -181,7 +181,7 @@ impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> {
#[derive(Debug, Clone)]
pub struct GepGetter<ElementOptic> {
/// The LLVM GEP index
pub gep_index: u32, // TODO: I think I'm not supposed to *just* use i32 for GEP like that
pub gep_index: u64,
/// Element (or field in the context of `struct`s) name. Used for cosmetics.
pub name: &'static str,
/// The lens to view the actual value after applying this [`FieldLens<T>`]
@ -208,7 +208,7 @@ impl<'ctx, ElementOptic: Optic<'ctx>> MemoryGetter<'ctx> for GepGetter<ElementOp
ctx.builder
.build_in_bounds_gep(
pointer,
&[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index as u64, false)],
&[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index, false)],
name,
)
.unwrap()
@ -220,7 +220,7 @@ impl<'ctx, ElementOptic: Optic<'ctx>> MemoryGetter<'ctx> for GepGetter<ElementOp
// Only used by [`FieldBuilder`]
#[derive(Debug)]
struct FieldInfo<'ctx> {
gep_index: u32,
gep_index: u64,
name: &'ctx str,
llvm_type: BasicTypeEnum<'ctx>,
}
@ -228,17 +228,18 @@ struct FieldInfo<'ctx> {
#[derive(Debug)]
pub struct FieldBuilder<'ctx> {
pub ctx: &'ctx Context,
gep_index_counter: u32,
gep_index_counter: u64,
struct_name: &'ctx str,
fields: Vec<FieldInfo<'ctx>>,
}
impl<'ctx> FieldBuilder<'ctx> {
#[must_use]
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() }
}
fn next_gep_index(&mut self) -> u32 {
fn next_gep_index(&mut self) -> u64 {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
index

View File

@ -1,5 +1,6 @@
use std::iter::once;
use crate::util::SizeVariant;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap;
use inkwell::{
@ -23,7 +24,6 @@ use crate::{
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
util::SizeVariant,
};
use super::*;