From 709844b85507dbe19d17813f72e87ea9c4060228 Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 14 Jul 2024 18:50:07 +0800 Subject: [PATCH] core: inkwell model abstraction --- flake.nix | 3 + nac3core/src/codegen/irrt/error_context.rs | 81 ++++++++++++ nac3core/src/codegen/irrt/util.rs | 79 ++++++++++++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/model/core.rs | 37 ++++++ nac3core/src/codegen/model/gep.rs | 138 +++++++++++++++++++++ nac3core/src/codegen/model/int.rs | 121 ++++++++++++++++++ nac3core/src/codegen/model/mod.rs | 11 ++ nac3core/src/codegen/model/pointer.rs | 74 +++++++++++ nac3core/src/codegen/model/slice.rs | 75 +++++++++++ 10 files changed, 620 insertions(+) create mode 100644 nac3core/src/codegen/irrt/error_context.rs create mode 100644 nac3core/src/codegen/irrt/util.rs create mode 100644 nac3core/src/codegen/model/core.rs create mode 100644 nac3core/src/codegen/model/gep.rs create mode 100644 nac3core/src/codegen/model/int.rs create mode 100644 nac3core/src/codegen/model/mod.rs create mode 100644 nac3core/src/codegen/model/pointer.rs create mode 100644 nac3core/src/codegen/model/slice.rs diff --git a/flake.nix b/flake.nix index a6ce5fce..79b1fa63 100644 --- a/flake.nix +++ b/flake.nix @@ -163,7 +163,10 @@ clippy pre-commit rustfmt + rust-analyzer ]; + # https://nixos.wiki/wiki/Rust#Shell.nix_example + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3core/src/codegen/irrt/error_context.rs b/nac3core/src/codegen/irrt/error_context.rs new file mode 100644 index 00000000..aca77c3e --- /dev/null +++ b/nac3core/src/codegen/irrt/error_context.rs @@ -0,0 +1,81 @@ +use crate::codegen::model::*; + +pub struct StrFields<'ctx> { + content: Field>>, + length: Field>, +} + +#[derive(Debug, Clone)] +pub struct Str<'ctx> { + sizet: IntModel<'ctx>, +} + +impl<'ctx> IsStruct<'ctx> for Str<'ctx> { + type Fields = StrFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "Str" + } + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + Self::Fields { + content: builder.add_field_auto("content"), + length: builder.add_field("length", self.sizet), + } + } +} + +type ErrorId = Int32; +pub struct ErrorIdsFields { + index_error: Field>, + value_error: Field>, + assertion_error: Field>, + runtime_error: Field>, +} +#[derive(Debug, Clone)] +pub struct ErrorIds; + +impl<'ctx> IsStruct<'ctx> for ErrorIds { + type Fields = ErrorIdsFields; + + fn struct_name(&self) -> &'static str { + "ErrorIds" + } + + fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields { + Self::Fields { + index_error: builder.add_field_auto("index_error"), + value_error: builder.add_field_auto("value_error"), + assertion_error: builder.add_field_auto("assertion_error"), + runtime_error: builder.add_field_auto("runtime_error"), + } + } +} + +pub struct ErrorContextFields { + error_id: Field>, + message_template: Field>>, + param1: Field>, + param2: Field>, + param3: Field>, +} +#[derive(Debug, Clone)] +pub struct ErrorContext; + +impl<'ctx> IsStruct<'ctx> for ErrorContext { + type Fields = ErrorContextFields; + + fn struct_name(&self) -> &'static str { + "ErrorIds" + } + + fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields { + Self::Fields { + error_id: builder.add_field_auto("error_id"), + message_template: builder.add_field_auto("message_template"), + param1: builder.add_field_auto("param1"), + param2: builder.add_field_auto("param2"), + param3: builder.add_field_auto("param3"), + } + } +} diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs new file mode 100644 index 00000000..4ca8882c --- /dev/null +++ b/nac3core/src/codegen/irrt/util.rs @@ -0,0 +1,79 @@ +use inkwell::{ + types::{BasicMetadataTypeEnum, BasicType, IntType}, + values::{AnyValue, BasicMetadataValueEnum}, +}; + +use crate::{ + codegen::{model::*, CodeGenContext}, + util::SizeVariant, +}; + +fn get_size_variant(ty: IntType) -> SizeVariant { + match ty.get_bit_width() { + 32 => SizeVariant::Bits32, + 64 => SizeVariant::Bits64, + _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), + } +} + +#[must_use] +pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String { + let mut fn_name = fn_name.to_owned(); + match get_size_variant(ty) { + SizeVariant::Bits32 => { + // Do nothing, `fn_name` already has the correct name + } + SizeVariant::Bits64 => { + // Append "64", this is the naming convention + fn_name.push_str("64"); + } + } + fn_name +} + +// TODO: Variadic argument? +pub struct FunctionBuilder<'ctx, 'a> { + ctx: &'a CodeGenContext<'ctx, 'a>, + fn_name: &'a str, + arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>, +} + +impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { + pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self { + FunctionBuilder { ctx, fn_name, arguments: Vec::new() } + } + + // The name is for self-documentation + #[must_use] + pub fn arg>(mut self, _name: &'static str, model: &M, value: &M::Value) -> Self { + self.arguments + .push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into())); + self + } + + pub fn returning>(self, name: &'static str, return_model: &M) -> S::MemoryValue { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = return_model.get_llvm_type(self.ctx.ctx); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); + return_model.check_llvm_value(ret.as_any_value_enum()) + } + + // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None + pub fn returning_void(self) { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = self.ctx.ctx.void_type(); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + self.ctx.builder.build_call(function, ¶m_vals, "").unwrap(); + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 17952369..85b963bb 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -41,6 +41,7 @@ pub mod extern_fns; mod generator; pub mod irrt; pub mod llvm_intrinsics; +pub mod model; pub mod numpy; pub mod stmt; diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs new file mode 100644 index 00000000..cfdad12f --- /dev/null +++ b/nac3core/src/codegen/model/core.rs @@ -0,0 +1,37 @@ +/* +MemoryGetter a + { load : Memory -> a + } + +MemorySetter a + { store : Memory -> a -> Memory + } +*/ + +use inkwell::{ + context::Context, + types::BasicTypeEnum, + values::{AnyValueEnum, BasicValueEnum}, +}; + +use crate::codegen::CodeGenContext; + +use super::Pointer; + +pub trait ModelValue<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx>; +} + +pub trait Model<'ctx>: Clone { + type Value: ModelValue<'ctx>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value; + + fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> { + Pointer { + element: self.clone(), + value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(), + } + } +} diff --git a/nac3core/src/codegen/model/gep.rs b/nac3core/src/codegen/model/gep.rs new file mode 100644 index 00000000..f54c0f0b --- /dev/null +++ b/nac3core/src/codegen/model/gep.rs @@ -0,0 +1,138 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, StructType}, + values::{AnyValueEnum, BasicValue, BasicValueEnum, StructValue}, +}; +use itertools::Itertools; + +use crate::codegen::CodeGenContext; + +use super::{Model, ModelValue, Pointer}; + +#[derive(Debug, Clone)] +pub struct Field { + pub gep_index: u64, + pub name: &'static str, + pub element: E, +} + +// Like [`Field`] but element must be [`BasicTypeEnum<'ctx>`] +#[derive(Debug)] +struct FieldLLVM<'ctx> { + gep_index: u64, + name: &'ctx str, + llvm_type: BasicTypeEnum<'ctx>, +} + +#[derive(Debug)] +pub struct FieldBuilder<'ctx> { + pub ctx: &'ctx Context, + gep_index_counter: u64, + struct_name: &'ctx str, + fields: Vec>, +} + +impl<'ctx> FieldBuilder<'ctx> { + #[must_use] + pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self { + FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() } + } + + fn next_gep_index(&mut self) -> u64 { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + index + } + + pub fn add_field>(&mut self, name: &'static str, element: E) -> Field { + let gep_index = self.next_gep_index(); + + self.fields.push(FieldLLVM { gep_index, name, llvm_type: element.get_llvm_type(self.ctx) }); + + Field { gep_index, name, element } + } + + pub fn add_field_auto + Default>(&mut self, name: &'static str) -> Field { + self.add_field(name, E::default()) + } +} + +pub trait IsStruct<'ctx>: Clone { + type Fields; + + fn struct_name(&self) -> &'static str; + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields; + + fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder) + } + + fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder); // Self::Fields is discarded + + let field_types = + builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec(); + ctx.struct_type(&field_types, false) + } +} + +// To play nice with Rust's trait resolution +#[derive(Debug, Clone)] +pub struct StructModel(pub S); + +// TODO: enrich it +pub struct Struct<'ctx, S> { + pub structure: S, + pub value: StructValue<'ctx>, +} + +impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel { + type Value = Struct<'ctx, S>; // TODO: enrich it + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_struct_type(ctx).as_basic_type_enum() + } + + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + // TODO: check structure + Struct { structure: self.0.clone(), value: value.into_struct_value() } + } +} + +impl<'ctx, S: IsStruct<'ctx>> Pointer<'ctx, StructModel> { + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetFieldFn, + ) -> Pointer<'ctx, E> + where + E: Model<'ctx>, + GetFieldFn: FnOnce(S::Fields) -> Field, + { + let fields = self.element.0.get_fields(ctx.ctx); + let field = get_field(fields); + + let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that + + let ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + self.value, + &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + field.name, + ) + .unwrap() + }; + + Pointer { element: field.element, value: ptr } + } +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs new file mode 100644 index 00000000..ebf06408 --- /dev/null +++ b/nac3core/src/codegen/model/int.rs @@ -0,0 +1,121 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType}, + values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue}, +}; + +use super::core::*; + +#[derive(Debug, Clone, Copy)] +pub struct IntModel<'ctx>(pub IntType<'ctx>); +pub struct Int<'ctx>(pub IntValue<'ctx>); + +impl<'ctx> ModelValue<'ctx> for Int<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.0.as_basic_value_enum() + } +} + +impl<'ctx> Model<'ctx> for IntModel<'ctx> { + type Value = Int<'ctx>; + + fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() + } + + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + let int = value.into_int_value(); + assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width()); + Int(int) + } +} + +#[derive(Debug, Clone, Default)] +pub struct FixedIntModel(pub T); +pub struct FixedInt<'ctx, T: IsFixedInt> { + pub int: T, + pub value: IntValue<'ctx>, +} + +pub trait IsFixedInt: Clone + Default { + fn get_int_type(ctx: &Context) -> IntType<'_>; + fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type +} + +impl<'ctx, T: IsFixedInt> ModelValue<'ctx> for FixedInt<'ctx, T> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel { + type Value = FixedInt<'ctx, T>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + T::get_int_type(ctx).as_basic_type_enum() + } + + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + let value = value.into_int_value(); + assert_eq!(value.get_type().get_bit_width(), T::get_bit_width()); + FixedInt { int: self.0.clone(), value } + } +} + +impl<'ctx, T: IsFixedInt> FixedIntModel { + pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> { + FixedInt { int: self.0.clone(), value: T::get_int_type(ctx).const_int(value, false) } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Bool; + +impl IsFixedInt for Bool { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.bool_type() + } + + fn get_bit_width() -> u32 { + 1 + } +} + +#[derive(Debug, Clone, Default)] +pub struct Byte; + +impl IsFixedInt for Byte { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i8_type() + } + + fn get_bit_width() -> u32 { + 8 + } +} + +#[derive(Debug, Clone, Default)] +pub struct Int32; + +impl IsFixedInt for Int32 { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i32_type() + } + + fn get_bit_width() -> u32 { + 32 + } +} + +#[derive(Debug, Clone, Default)] +pub struct Int64; + +impl IsFixedInt for Int64 { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i64_type() + } + + fn get_bit_width() -> u32 { + 64 + } +} diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs new file mode 100644 index 00000000..2ce82e0e --- /dev/null +++ b/nac3core/src/codegen/model/mod.rs @@ -0,0 +1,11 @@ +pub mod core; +pub mod gep; +pub mod int; +pub mod pointer; +pub mod slice; + +pub use core::*; +pub use gep::*; +pub use int::*; +pub use pointer::*; +pub use slice::*; diff --git a/nac3core/src/codegen/model/pointer.rs b/nac3core/src/codegen/model/pointer.rs new file mode 100644 index 00000000..45c6ac6f --- /dev/null +++ b/nac3core/src/codegen/model/pointer.rs @@ -0,0 +1,74 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum}, + values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use crate::codegen::CodeGenContext; + +use super::core::*; + +pub struct Pointer<'ctx, E: Model<'ctx>> { + pub element: E, + pub value: PointerValue<'ctx>, +} + +#[derive(Debug, Clone, Default)] +pub struct PointerModel(pub E); + +impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: &E::Value) { + ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap(); + } + + pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value { + let val = ctx.builder.build_load(self.value, name).unwrap(); + self.element.check_llvm_value(val.as_any_value_enum()) + } +} + +impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel { + type Value = Pointer<'ctx, E>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() + } + + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + // TODO: Check get_element_type()? for LLVM 14 at least... + Pointer { element: self.0.clone(), value: value.into_pointer_value() } + } +} + +pub struct OpaquePointer<'ctx>(pub PointerValue<'ctx>); + +#[derive(Debug, Clone, Default)] +pub struct OpaquePointerModel; + +impl<'ctx> ModelValue<'ctx> for OpaquePointer<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.0.as_basic_value_enum() + } +} + +impl<'ctx> Model<'ctx> for OpaquePointerModel { + type Value = OpaquePointer<'ctx>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum() + } + + fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + let ptr = value.into_pointer_value(); + // TODO: remove this check once LLVM pointers do not have `get_element_type()` + assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8); + OpaquePointer(ptr) + } +} diff --git a/nac3core/src/codegen/model/slice.rs b/nac3core/src/codegen/model/slice.rs new file mode 100644 index 00000000..8e3cefa0 --- /dev/null +++ b/nac3core/src/codegen/model/slice.rs @@ -0,0 +1,75 @@ +use inkwell::values::IntValue; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::{Int, Model, Pointer}; + +pub struct ArraySlice<'ctx, E: Model<'ctx>> { + pub num_elements: Int<'ctx>, + pub pointer: Pointer<'ctx, E>, +} + +impl<'ctx, E: Model<'ctx>> ArraySlice<'ctx, E> { + pub fn ix_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: &str, + ) -> Pointer<'ctx, E> { + let element_addr = + unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx], name).unwrap() }; + Pointer { value: element_addr, element: self.pointer.element.clone() } + } + + pub fn ix( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: &str, + ) -> Pointer<'ctx, E> { + let int_type = self.num_elements.0.get_type(); // NOTE: Weird get_type(), see comment under `trait Ixed` + + assert_eq!(int_type.get_bit_width(), idx.get_type().get_bit_width()); // Might as well check bit width to catch bugs + + // TODO: SGE or UGE? or make it defined by the implementee? + + // Check `0 <= index` + let lower_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLE, + int_type.const_zero(), + idx, + "lower_bounded", + ) + .unwrap(); + + // Check `index < num_elements` + let upper_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLT, + idx, + self.num_elements.0, + "upper_bounded", + ) + .unwrap(); + + // Compute `0 <= index && index < num_elements` + let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap(); + + // Assert `bounded` + ctx.make_assert( + generator, + bounded, + "0:IndexError", + "nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}", + [ Some(idx), Some(self.num_elements.0), None], + ctx.current_loc + ); + + // ...and finally do indexing + self.ix_unchecked(ctx, idx, name) + } +}