From 3cd36fddc30b465efd348333ec3777130e978810 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 9 Dec 2024 14:11:41 +0800 Subject: [PATCH] [core] codegen/types: Add check_struct_type_matches_fields Shorthand for checking if a type is representable by a StructFields instance. --- nac3core/src/codegen/types/structure.rs | 48 ++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index adfc53ab..4622e9b2 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, - types::{BasicTypeEnum, IntType}, + types::{BasicTypeEnum, IntType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, }; @@ -201,3 +201,49 @@ impl FieldIndexCounter { v } } + +type FieldTypeVerifier<'ctx> = dyn Fn(BasicTypeEnum<'ctx>) -> Result<(), String>; + +/// Checks whether [`llvm_ty`][StructType] contains the fields described by the given +/// [`StructFields`] instance. +/// +/// By default, this function will compare the type of each field in `expected_fields` against +/// `llvm_ty`. To override this behavior for individual fields, pass in overrides to +/// `custom_verifiers`, which will use the specified verifier when a field with the matching field +/// name is being checked. +pub(super) fn check_struct_type_matches_fields<'ctx>( + expected_fields: impl StructFields<'ctx>, + llvm_ty: StructType<'ctx>, + ty_name: &'static str, + custom_verifiers: &[(&str, &FieldTypeVerifier<'ctx>)], +) -> Result<(), String> { + let expected_fields = expected_fields.to_vec(); + + if llvm_ty.count_fields() != u32::try_from(expected_fields.len()).unwrap() { + return Err(format!( + "Expected {} fields in `{ty_name}`, got {}", + expected_fields.len(), + llvm_ty.count_fields(), + )); + } + + expected_fields + .into_iter() + .enumerate() + .map(|(i, (field_name, expected_ty))| { + (field_name, expected_ty, llvm_ty.get_field_type_at_index(i as u32).unwrap()) + }) + .try_for_each(|(field_name, expected_ty, actual_ty)| { + if let Some((_, verifier)) = + custom_verifiers.iter().find(|verifier| verifier.0 == field_name) + { + verifier(actual_ty) + } else if expected_ty == actual_ty { + Ok(()) + } else { + Err(format!("Expected {expected_ty} for `{ty_name}.{field_name}`, got {actual_ty}")) + } + })?; + + Ok(()) +}