From 7742fbf9e0ee0b8e2e4b7e54dd608cb1b208146c Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 14 Jul 2024 16:10:04 +0800 Subject: [PATCH] core: split codegen irrt into modules --- nac3core/src/codegen/irrt/classes.rs | 136 ---------------- .../codegen/irrt/{new.rs => error_context.rs} | 147 ++++++++++-------- nac3core/src/codegen/irrt/mod.rs | 9 +- nac3core/src/codegen/irrt/numpy.rs | 55 ++++++- nac3core/src/codegen/irrt/util.rs | 77 +++++++++ nac3core/src/codegen/mod.rs | 2 +- 6 files changed, 209 insertions(+), 217 deletions(-) delete mode 100644 nac3core/src/codegen/irrt/classes.rs rename nac3core/src/codegen/irrt/{new.rs => error_context.rs} (56%) create mode 100644 nac3core/src/codegen/irrt/util.rs diff --git a/nac3core/src/codegen/irrt/classes.rs b/nac3core/src/codegen/irrt/classes.rs deleted file mode 100644 index 520ba316..00000000 --- a/nac3core/src/codegen/irrt/classes.rs +++ /dev/null @@ -1,136 +0,0 @@ -use inkwell::types::IntType; - -use crate::codegen::optics::*; -use crate::codegen::CodeGenContext; - -#[derive(Debug, Clone)] -pub struct StrLens<'ctx> { - pub size_type: IntType<'ctx>, -} - -// TODO: nac3core has hardcoded a lot of "str" -pub struct StrFields<'ctx> { - pub content: GepGetter>>, - pub length: GepGetter>, -} - -impl<'ctx> StructureOptic<'ctx> for StrLens<'ctx> { - type Fields = StrFields<'ctx>; - - fn struct_name(&self) -> &'static str { - "str" - } - - fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { - StrFields { - content: builder.add_field("content", AddressLens(IntLens(builder.ctx.i8_type()))), - length: builder.add_field("length", IntLens(self.size_type)), - } - } -} - -pub struct NpArrayFields<'ctx> { - pub data: GepGetter>>, - pub itemsize: GepGetter>, - pub ndims: GepGetter>, - pub shape: GepGetter>>, - pub strides: GepGetter>>, -} - -#[derive(Debug, Clone, Copy)] -pub struct NpArrayLens<'ctx> { - pub size_type: IntType<'ctx>, -} - -impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { - type Fields = NpArrayFields<'ctx>; - - fn struct_name(&self) -> &'static str { - "NDArray" - } - - fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { - NpArrayFields { - data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))), - itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())), - ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())), - shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))), - strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))), - } - } -} - -// Other convenient utilities for NpArray -impl<'ctx> Address<'ctx, NpArrayLens<'ctx>> { - pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { - let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); - let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape"); - ArraySlice { num_elements: ndims, base: shape_base_ptr } - } - - pub fn strides_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { - let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); - let strides_base_ptr = self.focus(ctx, |fields| &fields.strides).load(ctx, "strides"); - ArraySlice { num_elements: ndims, base: strides_base_ptr } - } -} - -pub struct ErrorIdsFields<'ctx> { - pub index_error: GepGetter>, - pub value_error: GepGetter>, - pub assertion_error: GepGetter>, - pub runtime_error: GepGetter>, -} - -#[derive(Debug, Clone)] -pub struct ErrorIdsLens; - -impl<'ctx> StructureOptic<'ctx> for ErrorIdsLens { - type Fields = ErrorIdsFields<'ctx>; - - fn struct_name(&self) -> &'static str { - "ErrorIds" - } - - fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { - let i32_lens = IntLens(builder.ctx.i32_type()); - ErrorIdsFields { - index_error: builder.add_field("index_error", i32_lens), - value_error: builder.add_field("value_error", i32_lens), - assertion_error: builder.add_field("assertion_error", i32_lens), - runtime_error: builder.add_field("runtime_error", i32_lens), - } - } -} - -pub struct ErrorContextFields<'ctx> { - pub error_ids: GepGetter>, - pub error_id: GepGetter>, - pub message_template: GepGetter>>, - pub param1: GepGetter>, - pub param2: GepGetter>, - pub param3: GepGetter>, -} - -#[derive(Debug, Clone, Copy)] -pub struct ErrorContextLens; - -impl<'ctx> StructureOptic<'ctx> for ErrorContextLens { - type Fields = ErrorContextFields<'ctx>; - - fn struct_name(&self) -> &'static str { - "ErrorContext" - } - - fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { - ErrorContextFields { - error_ids: builder.add_field("error_ids", AddressLens(ErrorIdsLens)), - error_id: builder.add_field("error_id", IntLens(builder.ctx.i32_type())), - message_template: builder - .add_field("message_template", AddressLens(IntLens(builder.ctx.i8_type()))), - param1: builder.add_field("param1", IntLens(builder.ctx.i64_type())), - param2: builder.add_field("param2", IntLens(builder.ctx.i64_type())), - param3: builder.add_field("param3", IntLens(builder.ctx.i64_type())), - } - } -} diff --git a/nac3core/src/codegen/irrt/new.rs b/nac3core/src/codegen/irrt/error_context.rs similarity index 56% rename from nac3core/src/codegen/irrt/new.rs rename to nac3core/src/codegen/irrt/error_context.rs index 6fb90b18..310223cf 100644 --- a/nac3core/src/codegen/irrt/new.rs +++ b/nac3core/src/codegen/irrt/error_context.rs @@ -1,87 +1,96 @@ -use inkwell::{ - types::{BasicMetadataTypeEnum, BasicType, IntType}, - values::{AnyValue, BasicMetadataValueEnum, IntValue}, -}; +use inkwell::types::IntType; +use inkwell::values::IntValue; -use crate::codegen::{ - optics::{ - address::{Address, AddressLens}, - core::{Optic, OpticValue, Prism}, - int::IntLens, - }, - CodeGenContext, CodeGenerator, -}; -use crate::util::SizeVariant; +use crate::codegen::optics::*; +use crate::codegen::CodeGenContext; +use crate::codegen::CodeGenerator; -use super::classes::{ErrorContextLens, ErrorIdsLens, StrLens}; +use super::util::get_sized_dependent_function_name; +use super::util::FunctionBuilder; -fn get_size_variant(ty: IntType) -> SizeVariant { - match ty.get_bit_width() { - 32 => SizeVariant::Bits32, - 64 => SizeVariant::Bits64, - _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), - } +#[derive(Debug, Clone)] +pub struct StrLens<'ctx> { + pub size_type: IntType<'ctx>, } -#[must_use] -pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String { - let mut fn_name = fn_name.to_owned(); - match get_size_variant(ty) { - SizeVariant::Bits32 => { - // Do nothing, `fn_name` already has the correct name - } - SizeVariant::Bits64 => { - // Append "64", this is the naming convention - fn_name.push_str("64"); +// TODO: nac3core has hardcoded a lot of "str" +pub struct StrFields<'ctx> { + pub content: GepGetter>>, + pub length: GepGetter>, +} + +impl<'ctx> StructureOptic<'ctx> for StrLens<'ctx> { + type Fields = StrFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "str" + } + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + StrFields { + content: builder.add_field("content", AddressLens(IntLens(builder.ctx.i8_type()))), + length: builder.add_field("length", IntLens(self.size_type)), } } - fn_name } -// TODO: Variadic argument? -pub struct FunctionBuilder<'ctx, 'a> { - ctx: &'a CodeGenContext<'ctx, 'a>, - fn_name: &'a str, - arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>, +pub struct ErrorIdsFields<'ctx> { + pub index_error: GepGetter>, + pub value_error: GepGetter>, + pub assertion_error: GepGetter>, + pub runtime_error: GepGetter>, } -impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { - pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self { - FunctionBuilder { ctx, fn_name, arguments: Vec::new() } +#[derive(Debug, Clone)] +pub struct ErrorIdsLens; + +impl<'ctx> StructureOptic<'ctx> for ErrorIdsLens { + type Fields = ErrorIdsFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "ErrorIds" } - // The name is for self-documentation - #[must_use] - pub fn arg>(mut self, _name: &'static str, optic: &S, arg: &S::Value) -> Self { - self.arguments - .push((optic.get_llvm_type(self.ctx.ctx).into(), arg.get_llvm_value().into())); - self + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + let i32_lens = IntLens(builder.ctx.i32_type()); + ErrorIdsFields { + index_error: builder.add_field("index_error", i32_lens), + value_error: builder.add_field("value_error", i32_lens), + assertion_error: builder.add_field("assertion_error", i32_lens), + runtime_error: builder.add_field("runtime_error", i32_lens), + } + } +} + +pub struct ErrorContextFields<'ctx> { + pub error_ids: GepGetter>, + pub error_id: GepGetter>, + pub message_template: GepGetter>>, + pub param1: GepGetter>, + pub param2: GepGetter>, + pub param3: GepGetter>, +} + +#[derive(Debug, Clone, Copy)] +pub struct ErrorContextLens; + +impl<'ctx> StructureOptic<'ctx> for ErrorContextLens { + type Fields = ErrorContextFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "ErrorContext" } - pub fn returning>(self, name: &'static str, return_prism: &S) -> S::Value { - let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); - - let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { - let return_type = return_prism.get_llvm_type(self.ctx.ctx); - let fn_type = return_type.fn_type(¶m_tys, false); - self.ctx.module.add_function(self.fn_name, fn_type, None) - }); - - let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); - return_prism.review(ret.as_any_value_enum()) - } - - // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None - pub fn returning_void(self) { - let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); - - let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { - let return_type = self.ctx.ctx.void_type(); - let fn_type = return_type.fn_type(¶m_tys, false); - self.ctx.module.add_function(self.fn_name, fn_type, None) - }); - - self.ctx.builder.build_call(function, ¶m_vals, "").unwrap(); + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + ErrorContextFields { + error_ids: builder.add_field("error_ids", AddressLens(ErrorIdsLens)), + error_id: builder.add_field("error_id", IntLens(builder.ctx.i32_type())), + message_template: builder + .add_field("message_template", AddressLens(IntLens(builder.ctx.i8_type()))), + param1: builder.add_field("param1", IntLens(builder.ctx.i64_type())), + param2: builder.add_field("param2", IntLens(builder.ctx.i64_type())), + param3: builder.add_field("param3", IntLens(builder.ctx.i64_type())), + } } } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 57d4deb9..329b3826 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,8 +1,5 @@ use crate::typecheck::typedef::Type; -pub mod numpy; -mod test; - use super::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, @@ -24,8 +21,10 @@ use inkwell::{ use itertools::Either; use nac3parser::ast::Expr; -pub mod classes; -pub mod new; +pub mod error_context; +pub mod numpy; +pub mod test; +pub mod util; #[must_use] pub fn load_irrt(ctx: &Context) -> Module { diff --git a/nac3core/src/codegen/irrt/numpy.rs b/nac3core/src/codegen/irrt/numpy.rs index b4827f31..9a16c85f 100644 --- a/nac3core/src/codegen/irrt/numpy.rs +++ b/nac3core/src/codegen/irrt/numpy.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use inkwell::{ - types::{BasicType, BasicTypeEnum}, + types::{BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue}, }; @@ -16,13 +16,56 @@ use crate::{ }; use super::{ - classes::{ErrorContextLens, NpArrayLens}, - new::{ - check_error_context, get_sized_dependent_function_name, prepare_error_context, - FunctionBuilder, - }, + error_context::{check_error_context, prepare_error_context, ErrorContextLens}, + util::{get_sized_dependent_function_name, FunctionBuilder}, }; +pub struct NpArrayFields<'ctx> { + pub data: GepGetter>>, + pub itemsize: GepGetter>, + pub ndims: GepGetter>, + pub shape: GepGetter>>, + pub strides: GepGetter>>, +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayLens<'ctx> { + pub size_type: IntType<'ctx>, +} + +impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { + type Fields = NpArrayFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "NDArray" + } + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields { + NpArrayFields { + data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))), + itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())), + ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())), + shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))), + strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))), + } + } +} + +// Other convenient utilities for NpArray +impl<'ctx> Address<'ctx, NpArrayLens<'ctx>> { + pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { + let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); + let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape"); + ArraySlice { num_elements: ndims, base: shape_base_ptr } + } + + pub fn strides_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { + let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); + let strides_base_ptr = self.focus(ctx, |fields| &fields.strides).load(ctx, "strides"); + ArraySlice { num_elements: ndims, base: strides_base_ptr } + } +} + type ProducerWriteToArray<'ctx, G, ElementOptic> = Box< dyn Fn( &mut G, diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs new file mode 100644 index 00000000..1a9c448a --- /dev/null +++ b/nac3core/src/codegen/irrt/util.rs @@ -0,0 +1,77 @@ +use inkwell::{ + types::{BasicMetadataTypeEnum, BasicType, IntType}, + values::{AnyValue, BasicMetadataValueEnum}, +}; + +use crate::codegen::optics::*; +use crate::{codegen::CodeGenContext, util::SizeVariant}; + +fn get_size_variant(ty: IntType) -> SizeVariant { + match ty.get_bit_width() { + 32 => SizeVariant::Bits32, + 64 => SizeVariant::Bits64, + _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), + } +} + +#[must_use] +pub fn get_sized_dependent_function_name(ty: IntType, fn_name: &str) -> String { + let mut fn_name = fn_name.to_owned(); + match get_size_variant(ty) { + SizeVariant::Bits32 => { + // Do nothing, `fn_name` already has the correct name + } + SizeVariant::Bits64 => { + // Append "64", this is the naming convention + fn_name.push_str("64"); + } + } + fn_name +} + +// TODO: Variadic argument? +pub struct FunctionBuilder<'ctx, 'a> { + ctx: &'a CodeGenContext<'ctx, 'a>, + fn_name: &'a str, + arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>, +} + +impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { + pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self { + FunctionBuilder { ctx, fn_name, arguments: Vec::new() } + } + + // The name is for self-documentation + #[must_use] + pub fn arg>(mut self, _name: &'static str, optic: &S, arg: &S::Value) -> Self { + self.arguments + .push((optic.get_llvm_type(self.ctx.ctx).into(), arg.get_llvm_value().into())); + self + } + + pub fn returning>(self, name: &'static str, return_prism: &S) -> S::Value { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = return_prism.get_llvm_type(self.ctx.ctx); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); + return_prism.review(ret.as_any_value_enum()) + } + + // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None + pub fn returning_void(self) { + let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); + + let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { + let return_type = self.ctx.ctx.void_type(); + let fn_type = return_type.fn_type(¶m_tys, false); + self.ctx.module.add_function(self.fn_name, fn_type, None) + }); + + self.ctx.builder.build_call(function, ¶m_vals, "").unwrap(); + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index b5d7a2d1..7f38e43c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -23,7 +23,7 @@ use inkwell::{ values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use irrt::classes::StrLens; +use irrt::error_context::StrLens; use itertools::Itertools; use nac3parser::ast::{Location, Stmt, StrRef}; use optics::Optic as _;