From 148900302e92d2b9dc2aedaefda98415d80afae9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 23 Jan 2024 18:27:00 +0800 Subject: [PATCH] core: Add RangeValue and helper functions --- nac3core/src/codegen/classes.rs | 159 +++++++++++++++++++++++++++++- nac3core/src/codegen/expr.rs | 19 ++-- nac3core/src/codegen/stmt.rs | 4 +- nac3core/src/toplevel/builtins.rs | 3 +- 4 files changed, 168 insertions(+), 17 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 66f74ac4..4f8e7851 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -28,7 +28,7 @@ impl<'ctx> ListValue<'ctx> { ) -> Result<(), String> { let llvm_list_ty = value.get_type().get_element_type(); let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - panic!("Expected struct type for `list` type, got {llvm_list_ty}") + return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")) }; if llvm_list_ty.count_fields() != 2 { return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields())) @@ -223,3 +223,160 @@ impl<'ctx> ListDataProxy<'ctx> { ctx.builder.build_load(ptr, name.unwrap_or_default()) } } + +#[cfg(not(debug_assertions))] +pub fn assert_is_range(_value: PointerValue) {} + +#[cfg(debug_assertions)] +pub fn assert_is_range(value: PointerValue) { + if let Err(msg) = RangeValue::is_instance(value) { + panic!("{msg}") + } +} + +/// Proxy type for accessing a `range` value in LLVM. +#[derive(Copy, Clone)] +pub struct RangeValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> RangeValue<'ctx> { + /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. + pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> { + let llvm_range_ty = value.get_type().get_element_type(); + let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { + return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")) + }; + if llvm_range_ty.len() != 3 { + return Err(format!("Expected 3 elements for `range` type, got {}", llvm_range_ty.len())) + } + + let llvm_range_elem_ty = llvm_range_ty.get_element_type(); + let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else { + return Err(format!("Expected int type for `range` element type, got {llvm_range_elem_ty}")) + }; + if llvm_range_elem_ty.get_bit_width() != 32 { + return Err(format!("Expected 32-bit int type for `range` element type, got {}", + llvm_range_elem_ty.get_bit_width())) + } + + Ok(()) + } + + /// Creates an [RangeValue] from a [PointerValue]. + pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { + assert_is_range(ptr); + RangeValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `range` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + fn get_start_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.start.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], + var_name.as_str(), + ) + } + } + + fn get_end_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.end.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + var_name.as_str(), + ) + } + } + + fn get_step_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.step.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], + var_name.as_str(), + ) + } + } + + /// Stores the `start` value into this instance. + pub fn store_start( + &self, + ctx: &CodeGenContext<'ctx, '_>, + start: IntValue<'ctx>, + ) { + debug_assert_eq!(start.get_type().get_bit_width(), 32); + + let pstart = self.get_start_ptr(ctx); + ctx.builder.build_store(pstart, start); + } + + /// Returns the `start` value of this `range`. + pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstart = self.get_start_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.start"))) + .unwrap_or_default(); + + ctx.builder.build_load(pstart, var_name.as_str()).into_int_value() + } + + /// Stores the `end` value into this instance. + pub fn store_end( + &self, + ctx: &CodeGenContext<'ctx, '_>, + end: IntValue<'ctx>, + ) { + debug_assert_eq!(end.get_type().get_bit_width(), 32); + + let pend = self.get_start_ptr(ctx); + ctx.builder.build_store(pend, end); + } + + /// Returns the `end` value of this `range`. + pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pend = self.get_end_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.end"))) + .unwrap_or_default(); + + ctx.builder.build_load(pend, var_name.as_str()).into_int_value() + } + + /// Stores the `step` value into this instance. + pub fn store_step( + &self, + ctx: &CodeGenContext<'ctx, '_>, + step: IntValue<'ctx>, + ) { + debug_assert_eq!(step.get_type().get_bit_width(), 32); + + let pstep = self.get_start_ptr(ctx); + ctx.builder.build_store(pstep, step); + } + + /// Returns the `step` value of this `range`. + pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstep = self.get_step_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.step"))) + .unwrap_or_default(); + + ctx.builder.build_load(pstep, var_name.as_str()).into_int_value() + } +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 734e9451..58f090f5 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, RangeValue}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_type, @@ -870,18 +870,11 @@ pub fn gen_call<'ctx, G: CodeGenerator>( /// respectively. pub fn destructure_range<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - range: PointerValue<'ctx>, + range: RangeValue<'ctx>, ) -> (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>) { - let int32 = ctx.ctx.i32_type(); - let start = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(0, false)], Some("range.start")) - .into_int_value(); - let end = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(1, false)], Some("range.stop")) - .into_int_value(); - let step = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(2, false)], Some("range.step")) - .into_int_value(); + let start = range.load_start(ctx, None); + let end = range.load_end(ctx, None); + let step = range.load_step(ctx, None); (start, end, step) } @@ -965,7 +958,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let list_content; if is_range { - let iter_val = iter_val.into_pointer_value(); + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff"); // add 1 to the length as the value is rounded to zero diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 8ac5a7db..b9fcda01 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -6,7 +6,7 @@ use super::{ }; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, RangeValue}, expr::gen_binop_expr, gen_in_range_check, }, @@ -321,7 +321,7 @@ pub fn gen_for( return Ok(()) }; if is_iterable_range_expr { - let iter_val = iter_val.into_pointer_value(); + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index d2eb458a..d4bf68cb 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,7 @@ use super::*; use crate::{ codegen::{ + classes::RangeValue, expr::destructure_range, irrt::{ calculate_len_for_slice_range, @@ -1453,7 +1454,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = arg.into_pointer_value(); + let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) } else {