[core] codegen: Refactor StructField getters and setters

This commit is contained in:
David Mak 2025-01-24 10:53:14 +08:00 committed by sb10q
parent b521bc0c82
commit eec62c3bbb
8 changed files with 70 additions and 49 deletions

View File

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use inkwell::{ use inkwell::{
context::AsContextRef, context::AsContextRef,
types::{BasicTypeEnum, IntType, PointerType, StructType}, types::{BasicTypeEnum, IntType, PointerType, StructType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, values::{AggregateValueEnum, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -203,17 +203,38 @@ where
/// Gets the value of this field for a given `obj`. /// Gets the value of this field for a given `obj`.
#[must_use] #[must_use]
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value { pub fn extract_value(&self, ctx: &CodeGenContext<'ctx, '_>, obj: StructValue<'ctx>) -> Value {
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap() Value::try_from(
ctx.builder
.build_extract_value(
obj,
self.index,
&format!("{}.{}", obj.get_name().to_str().unwrap(), self.name),
)
.unwrap(),
)
.unwrap()
} }
/// Sets the value of this field for a given `obj`. /// Sets the value of this field for a given `obj`.
pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) { #[must_use]
obj.set_field_at_index(self.index, value); pub fn insert_value(
&self,
ctx: &CodeGenContext<'ctx, '_>,
obj: StructValue<'ctx>,
value: Value,
) -> StructValue<'ctx> {
let obj_name = obj.get_name().to_str().unwrap();
let new_obj_name = if obj_name.chars().all(char::is_numeric) { "" } else { obj_name };
ctx.builder
.build_insert_value(obj, value, self.index, new_obj_name)
.map(AggregateValueEnum::into_struct_value)
.unwrap()
} }
/// Gets the value of this field for a pointer-to-structure. /// Loads the value of this field for a pointer-to-structure.
pub fn get( pub fn load(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>, pobj: PointerValue<'ctx>,
@ -229,8 +250,8 @@ where
.unwrap() .unwrap()
} }
/// Sets the value of this field for a pointer-to-structure. /// Stores the value of this field for a pointer-to-structure.
pub fn set( pub fn store(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>, pobj: PointerValue<'ctx>,

View File

@ -45,7 +45,7 @@ impl<'ctx> ListValue<'ctx> {
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
self.items_field(ctx).set(ctx, self.value, data, self.name); self.items_field(ctx).store(ctx, self.value, data, self.name);
} }
/// Convenience method for creating a new array storing data elements with the given element /// Convenience method for creating a new array storing data elements with the given element
@ -91,7 +91,7 @@ impl<'ctx> ListValue<'ctx> {
pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) {
debug_assert_eq!(size.get_type(), ctx.get_size_type()); debug_assert_eq!(size.get_type(), ctx.get_size_type());
self.len_field(ctx).set(ctx, self.value, size, self.name); self.len_field(ctx).store(ctx, self.value, size, self.name);
} }
/// Returns the size of this `list` as a value. /// Returns the size of this `list` as a value.
@ -100,7 +100,7 @@ impl<'ctx> ListValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
self.len_field(ctx).get(ctx, self.value, name) self.len_field(ctx).load(ctx, self.value, name)
} }
/// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`.

View File

@ -44,7 +44,7 @@ impl<'ctx> ShapeEntryValue<'ctx> {
/// Stores the number of dimensions into this value. /// Stores the number of dimensions into this value.
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.ndims_field().set(ctx, self.value, value, self.name); self.ndims_field().store(ctx, self.value, value, self.name);
} }
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
@ -53,7 +53,7 @@ impl<'ctx> ShapeEntryValue<'ctx> {
/// Stores the shape into this value. /// Stores the shape into this value.
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.shape_field().set(ctx, self.value, value, self.name); self.shape_field().store(ctx, self.value, value, self.name);
} }
} }

View File

@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> {
} }
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); self.ndims_field().store(ctx, self.as_abi_value(ctx), value, self.name);
} }
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
@ -49,11 +49,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> {
} }
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); self.shape_field().store(ctx, self.as_abi_value(ctx), value, self.name);
} }
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.shape_field().get(ctx, self.value, self.name) self.shape_field().load(ctx, self.value, self.name)
} }
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
@ -61,11 +61,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> {
} }
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); self.data_field().store(ctx, self.as_abi_value(ctx), value, self.name);
} }
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field().get(ctx, self.value, self.name) self.data_field().load(ctx, self.value, self.name)
} }
} }
@ -129,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| Ok(self.is_c_contiguous(ctx)),
|_, ctx| { |_, ctx| {
// This ndarray is contiguous. // This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); let data = self.data_field(ctx).load(ctx, self.as_abi_value(ctx), self.name);
let data = ctx let data = ctx
.builder .builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")

View File

@ -47,11 +47,11 @@ impl<'ctx> NDIndexValue<'ctx> {
} }
pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.type_field().get(ctx, self.value, self.name) self.type_field().load(ctx, self.value, self.name)
} }
pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.type_field().set(ctx, self.value, value, self.name); self.type_field().store(ctx, self.value, value, self.name);
} }
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
@ -59,11 +59,11 @@ impl<'ctx> NDIndexValue<'ctx> {
} }
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field().get(ctx, self.value, self.name) self.data_field().load(ctx, self.value, self.name)
} }
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.data_field().set(ctx, self.value, value, self.name); self.data_field().store(ctx, self.value, value, self.name);
} }
} }

View File

@ -94,12 +94,12 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) {
debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); debug_assert_eq!(itemsize.get_type(), ctx.get_size_type());
self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); self.itemsize_field(ctx).store(ctx, self.value, itemsize, self.name);
} }
/// Returns the size of each element of this `NDArray` as a value. /// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.itemsize_field(ctx).get(ctx, self.value, self.name) self.itemsize_field(ctx).load(ctx, self.value, self.name)
} }
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Stores the array of dimension sizes `dims` into this instance. /// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
self.shape_field(ctx).set(ctx, self.value, dims, self.name); self.shape_field(ctx).store(ctx, self.value, dims, self.name);
} }
/// Convenience method for creating a new array storing dimension sizes with the given `size`. /// Convenience method for creating a new array storing dimension sizes with the given `size`.
@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Stores the array of stride sizes `strides` into this instance. /// Stores the array of stride sizes `strides` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) {
self.strides_field(ctx).set(ctx, self.value, strides, self.name); self.strides_field(ctx).store(ctx, self.value, strides, self.name);
} }
/// Convenience method for creating a new array storing the stride with the given `size`. /// Convenience method for creating a new array storing the stride with the given `size`.
@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> {
.builder .builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(); .unwrap();
self.data_field(ctx).set(ctx, self.value, data.into_pointer_value(), self.name); self.data_field(ctx).store(ctx, self.value, data.into_pointer_value(), self.name);
} }
/// Convenience method for creating a new array storing data elements with the given element /// Convenience method for creating a new array storing data elements with the given element
@ -508,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
self.0.shape_field(ctx).get(ctx, self.0.value, self.0.name) self.0.shape_field(ctx).load(ctx, self.0.value, self.0.name)
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
@ -606,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
self.0.strides_field(ctx).get(ctx, self.0.value, self.0.name) self.0.strides_field(ctx).load(ctx, self.0.value, self.0.name)
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
@ -704,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
self.0.data_field(ctx).get(ctx, self.0.value, self.0.name) self.0.data_field(ctx).load(ctx, self.0.value, self.0.name)
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(

View File

@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> {
pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let elem_ty = self.parent.dtype; let elem_ty = self.parent.dtype;
let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); let p = self.element_field(ctx).load(ctx, self.as_abi_value(ctx), self.name);
ctx.builder ctx.builder
.build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap() .unwrap()
@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> {
/// Get the index of the current element if this ndarray were a flat ndarray. /// Get the index of the current element if this ndarray were a flat ndarray.
#[must_use] #[must_use]
pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) self.nth_field(ctx).load(ctx, self.as_abi_value(ctx), self.name)
} }
/// Get the indices of the current element. /// Get the indices of the current element.

View File

@ -42,7 +42,7 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.start_defined_field().get(ctx, self.value, self.name) self.start_defined_field().load(ctx, self.value, self.name)
} }
fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> { fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
@ -50,22 +50,22 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.start_field().get(ctx, self.value, self.name) self.start_field().load(ctx, self.value, self.name)
} }
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) { pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value { match value {
Some(start) => { Some(start) => {
self.start_defined_field().set( self.start_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_all_ones(), ctx.ctx.bool_type().const_all_ones(),
self.name, self.name,
); );
self.start_field().set(ctx, self.value, start, self.name); self.start_field().store(ctx, self.value, start, self.name);
} }
None => self.start_defined_field().set( None => self.start_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),
@ -79,7 +79,7 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.stop_defined_field().get(ctx, self.value, self.name) self.stop_defined_field().load(ctx, self.value, self.name)
} }
fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> { fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
@ -87,22 +87,22 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.stop_field().get(ctx, self.value, self.name) self.stop_field().load(ctx, self.value, self.name)
} }
pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) { pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value { match value {
Some(stop) => { Some(stop) => {
self.stop_defined_field().set( self.stop_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_all_ones(), ctx.ctx.bool_type().const_all_ones(),
self.name, self.name,
); );
self.stop_field().set(ctx, self.value, stop, self.name); self.stop_field().store(ctx, self.value, stop, self.name);
} }
None => self.stop_defined_field().set( None => self.stop_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),
@ -116,7 +116,7 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.step_defined_field().get(ctx, self.value, self.name) self.step_defined_field().load(ctx, self.value, self.name)
} }
fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> { fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
@ -124,22 +124,22 @@ impl<'ctx> SliceValue<'ctx> {
} }
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.step_field().get(ctx, self.value, self.name) self.step_field().load(ctx, self.value, self.name)
} }
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) { pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value { match value {
Some(step) => { Some(step) => {
self.step_defined_field().set( self.step_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_all_ones(), ctx.ctx.bool_type().const_all_ones(),
self.name, self.name,
); );
self.step_field().set(ctx, self.value, step, self.name); self.step_field().store(ctx, self.value, step, self.name);
} }
None => self.step_defined_field().set( None => self.step_defined_field().store(
ctx, ctx,
self.value, self.value,
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),