diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index db1e9b66..385a6a49 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -31,7 +31,9 @@ use crate::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, - typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, + typedef::{ + iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap, + }, }, }; use inkwell::{ @@ -1061,98 +1063,132 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.build_store(index, zero_size_t).unwrap(); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); - let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); let list; let list_content; - if is_range { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); - let (start, stop, step) = destructure_range(ctx, iter_val); - let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); - // add 1 to the length as the value is rounded to zero - // the length may be 1 more than the actual length if the division is exact, but the - // length is a upper bound only anyway so it does not matter. - let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); - let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); - // in case length is non-positive - let is_valid = - ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); + // The implementation of the for loop logic depends on + // the typechecker type of `iter`. + let iter_ty = iter.custom.unwrap(); + match &*ctx.unifier.get_ty(iter_ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // `iter` is a `List[T]`, and `T` is the element type - let list_alloc_size = ctx - .builder - .build_select( - is_valid, - ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(), - zero_size_t, - "listcomp.alloc_size", + // Get the `T` out of `List[T]` - it is defined to be the 1st param. + let list_elem_ty = iter_type_vars(params).nth(0).unwrap().ty; + + let length = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, int32.const_int(1, false)], + Some("length"), + ) + .into_int_value(); + list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + list_content = list.data().base_ptr(ctx, generator); + let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; + // counter = -1 + ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)).unwrap(); + ctx.builder.build_unconditional_branch(test_bb).unwrap(); + + ctx.builder.position_at_end(test_bb); + let tmp = + ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap(); + let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); + ctx.builder.build_store(counter, tmp).unwrap(); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); + ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap(); + + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, zero_32], + Some("arr.addr"), + ) + .into_pointer_value(); + let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); + generator.gen_assign(ctx, target, val.into(), list_elem_ty)?; + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => + { + // `iter` is a `range(start, stop, step)`, and `int32` is the element type + + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + let (start, stop, step) = destructure_range(ctx, iter_val); + let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); + // add 1 to the length as the value is rounded to zero + // the length may be 1 more than the actual length if the division is exact, but the + // length is a upper bound only anyway so it does not matter. + let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); + let length = + ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); + // in case length is non-positive + let is_valid = + ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); + + let list_alloc_size = ctx + .builder + .build_select( + is_valid, + ctx.builder + .build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len") + .unwrap(), + zero_size_t, + "listcomp.alloc_size", + ) + .unwrap(); + list = allocate_list( + generator, + ctx, + Some(elem_ty), + list_alloc_size.into_int_value(), + Some("listcomp.addr"), + ); + list_content = list.data().base_ptr(ctx, generator); + + let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); + ctx.builder + .build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap()) + .unwrap(); + + ctx.builder + .build_conditional_branch( + gen_in_range_check(ctx, start, stop, step), + test_bb, + cont_bb, + ) + .unwrap(); + + ctx.builder.position_at_end(test_bb); + // add and test + let tmp = ctx + .builder + .build_int_add( + ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), + step, + "start_loop", + ) + .unwrap(); + ctx.builder.build_store(i, tmp).unwrap(); + ctx.builder + .build_conditional_branch( + gen_in_range_check(ctx, tmp, stop, step), + body_bb, + cont_bb, + ) + .unwrap(); + + ctx.builder.position_at_end(body_bb); + } + _ => { + panic!( + "unsupported iterator type in list comprehension: {}", + ctx.unifier.stringify(iter_ty) ) - .unwrap(); - list = allocate_list( - generator, - ctx, - Some(elem_ty), - list_alloc_size.into_int_value(), - Some("listcomp.addr"), - ); - list_content = list.data().base_ptr(ctx, generator); - - let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); - ctx.builder - .build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap()) - .unwrap(); - - ctx.builder - .build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb) - .unwrap(); - - ctx.builder.position_at_end(test_bb); - // add and test - let tmp = ctx - .builder - .build_int_add( - ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), - step, - "start_loop", - ) - .unwrap(); - ctx.builder.build_store(i, tmp).unwrap(); - ctx.builder - .build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb) - .unwrap(); - - ctx.builder.position_at_end(body_bb); - } else { - let length = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero_size_t, int32.const_int(1, false)], - Some("length"), - ) - .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); - list_content = list.data().base_ptr(ctx, generator); - let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; - // counter = -1 - ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)).unwrap(); - ctx.builder.build_unconditional_branch(test_bb).unwrap(); - - ctx.builder.position_at_end(test_bb); - let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap(); - let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); - ctx.builder.build_store(counter, tmp).unwrap(); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); - ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap(); - - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero_size_t, zero_32], - Some("arr.addr"), - ) - .into_pointer_value(); - let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); - generator.gen_assign(ctx, target, val.into())?; + } } // Emits the content of `cont_bb` @@ -2190,6 +2226,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( None => None, Some(value_expr) => Some( slice_index_model.review( + ctx.ctx, generator .gen_expr(ctx, value_expr)? .unwrap() @@ -2210,6 +2247,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error let index = slice_index_model.review( + ctx.ctx, generator .gen_expr(ctx, subscript_expr)? .unwrap() @@ -2931,7 +2969,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet })); let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; - ndarray_ptr_model.review(v.as_any_value_enum()) + ndarray_ptr_model.review(ctx.ctx, v.as_any_value_enum()) } else { return Ok(None); }; diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index bb822f19..5259d280 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -123,11 +123,12 @@ pub trait CodeGenerator { ctx: &mut CodeGenContext<'ctx, '_>, target: &Expr>, value: ValueEnum<'ctx>, + value_ty: Type, ) -> Result<(), String> where Self: Sized, { - gen_assign(self, ctx, target, value) + gen_assign(self, ctx, target, value, value_ty) } /// Generate code for a while expression. diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs index 4b27ecb9..9b0e3cf0 100644 --- a/nac3core/src/codegen/irrt/util.rs +++ b/nac3core/src/codegen/irrt/util.rs @@ -61,7 +61,7 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { }); let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); - return_model.review(ret.as_any_value_enum()) + return_model.review(self.ctx.ctx, ret.as_any_value_enum()) } // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 13c4830b..5f7628e3 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -10,7 +10,7 @@ MemorySetter a use inkwell::{ context::Context, - types::BasicTypeEnum, + types::{AnyTypeEnum, BasicTypeEnum}, values::{AnyValueEnum, BasicValueEnum}, }; @@ -25,15 +25,19 @@ pub trait ModelValue<'ctx>: Clone + Copy { // Should have been within [`Model`], // but rust object safety requirements made it necessary to // split this interface out -pub trait CanCheckLLVMType { - fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String>; +pub trait CanCheckLLVMType<'ctx> { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String>; } -pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType + Sized { +pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType<'ctx> + Sized { type Value: ModelValue<'ctx>; fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value; + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value; fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> { Pointer { diff --git a/nac3core/src/codegen/model/gep.rs b/nac3core/src/codegen/model/gep.rs index 14ac7cf0..1e201a0e 100644 --- a/nac3core/src/codegen/model/gep.rs +++ b/nac3core/src/codegen/model/gep.rs @@ -1,9 +1,9 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, StructType}, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, StructType}, values::{AnyValueEnum, BasicValue, BasicValueEnum, StructValue}, }; -use itertools::Itertools; +use itertools::{izip, Itertools}; use crate::codegen::CodeGenContext; @@ -20,7 +20,10 @@ pub struct Field { struct FieldLLVM<'ctx> { gep_index: u64, name: &'ctx str, - llvm_type: Box, + llvm_type: BasicTypeEnum<'ctx>, + + // Only CanCheckLLVMType is needed, dont put in the whole `Model<'ctx>` + llvm_type_model: Box + 'ctx>, } pub struct FieldBuilder<'ctx> { @@ -42,46 +45,27 @@ impl<'ctx> FieldBuilder<'ctx> { index } - pub fn add_field>(&mut self, name: &'static str, element: E) -> Field { + pub fn add_field + 'ctx>(&mut self, name: &'static str, element: E) -> Field { let gep_index = self.next_gep_index(); - self.fields.push(FieldLLVM { gep_index, name, llvm_type: element.get_llvm_type(self.ctx) }); + self.fields.push(FieldLLVM { + gep_index, + name, + llvm_type: element.get_llvm_type(self.ctx), + llvm_type_model: Box::new(element), + }); Field { gep_index, name, element } } - pub fn add_field_auto + Default>(&mut self, name: &'static str) -> Field { + pub fn add_field_auto + Default + 'ctx>( + &mut self, + name: &'static str, + ) -> Field { self.add_field(name, E::default()) } } -fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String> -where - A: BasicType<'ctx>, - B: BasicType<'ctx>, -{ - let expected = expected.as_basic_type_enum(); - let got = got.as_basic_type_enum(); - - // Put those logic into here, - // otherwise there is always a fallback reporting on any kind of mismatch - match (expected, got) { - (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { - if expected.get_bit_width() != got.get_bit_width() { - return Err(format!( - "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" - )); - } - } - (expected, got) => { - if expected != got { - return Err(format!("Expected {expected}, got {got}")); - } - } - } - Ok(()) -} - pub trait IsStruct<'ctx>: Clone + Copy { type Fields; @@ -98,14 +82,39 @@ pub trait IsStruct<'ctx>: Clone + Copy { let mut builder = FieldBuilder::new(ctx, self.struct_name()); self.build_fields(&mut builder); // Self::Fields is discarded - let field_types = - builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec(); - ctx.struct_type(&field_types, false).as_basic_type_enum().into_pointer_type().get_el + let field_types = builder.fields.iter().map(|f| f.llvm_type).collect_vec(); + ctx.struct_type(&field_types, false) } - fn check_struct_type(&self) { - // Datatypes behind - // check_basic_types_match + fn check_struct_type( + &self, + ctx: &'ctx Context, + scrutinee: StructType<'ctx>, + ) -> Result<(), String> { + // Details about scrutinee + let scrutinee_field_types = scrutinee.get_field_types(); + + // Details about the defined specifications of this struct + // We will access them through builder + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder); + + // Check # of fields + if builder.fields.len() != scrutinee_field_types.len() { + return Err(format!( + "Expecting struct to have {} field(s), but scrutinee has {} field(s)", + builder.fields.len(), + scrutinee_field_types.len() + )); + } + + // Check the types of each field + // TODO: Traceback? + for (f, scrutinee_field_type) in izip!(builder.fields, scrutinee_field_types) { + f.llvm_type_model.check_llvm_type(ctx, scrutinee_field_type.as_any_type_enum())?; + } + + Ok(()) } } @@ -125,8 +134,18 @@ impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> { } impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel { - fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String> { - todo!() + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + // Check if scrutinee is even a struct type + let AnyTypeEnum::StructType(scrutinee) = scrutinee else { + return Err(format!("Expecting a struct type, but got {scrutinee:?}")); + }; + + // Ok. now check the struct type *thoroughly* + self.0.check_struct_type(ctx, scrutinee) } } @@ -137,8 +156,10 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel { self.0.get_struct_type(ctx).as_basic_type_enum() } - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { - // TODO: check structure + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + // Check that `value` is not some bogus values or an incorrect StructValue + self.check_llvm_type(ctx, value.get_type()).unwrap(); + Struct { structure: self.0, value: value.into_struct_value() } } } diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs index 611c7bee..a790cd54 100644 --- a/nac3core/src/codegen/model/int.rs +++ b/nac3core/src/codegen/model/int.rs @@ -1,6 +1,6 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, IntType}, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue}, }; @@ -8,6 +8,38 @@ use crate::codegen::CodeGenContext; use super::core::*; +fn check_int_llvm_type<'ctx>( + scrutinee: AnyTypeEnum<'ctx>, + expected_int_type: IntType<'ctx>, +) -> Result<(), String> { + // Check if llvm_type is int type + let AnyTypeEnum::IntType(scrutinee) = scrutinee else { + return Err(format!("Expecting an int type but got {scrutinee:?}")); + }; + + // Check bit width + if scrutinee.get_bit_width() != expected_int_type.get_bit_width() { + return Err(format!( + "Expecting an int type of {}-bit(s) but got int type {}-bit(s)", + expected_int_type.get_bit_width(), + scrutinee.get_bit_width() + )); + } + + Ok(()) +} + +fn review_int_llvm_value<'ctx>( + value: AnyValueEnum<'ctx>, + expected_int_type: IntType<'ctx>, +) -> Result, String> { + // Check if value is of int type, error if that is anything else + check_int_llvm_type(value.get_type().as_any_type_enum(), expected_int_type)?; + + // Ok, it is must be an int + Ok(value.into_int_value()) +} + #[derive(Debug, Clone, Copy)] pub struct IntModel<'ctx>(pub IntType<'ctx>); @@ -20,6 +52,16 @@ impl<'ctx> ModelValue<'ctx> for Int<'ctx> { } } +impl<'ctx> CanCheckLLVMType<'ctx> for IntModel<'ctx> { + fn check_llvm_type( + &self, + _ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + check_int_llvm_type(scrutinee, self.0) + } +} + impl<'ctx> Model<'ctx> for IntModel<'ctx> { type Value = Int<'ctx>; @@ -27,9 +69,9 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> { self.0.as_basic_type_enum() } - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { let int = value.into_int_value(); - assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width()); + self.check_llvm_type(ctx, int.get_type().as_any_type_enum()).unwrap(); Int(int) } } @@ -90,6 +132,47 @@ pub struct FixedInt<'ctx, T: IsFixedInt> { pub value: IntValue<'ctx>, } +// Default instance is to enable `FieldBuilder::add_field_auto` +pub trait IsFixedInt: Clone + Copy + Default { + fn get_int_type(ctx: &Context) -> IntType<'_>; + fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type +} + +impl<'ctx, T: IsFixedInt> ModelValue<'ctx> for FixedInt<'ctx, T> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, T: IsFixedInt> CanCheckLLVMType<'ctx> for FixedIntModel { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + check_int_llvm_type(scrutinee, T::get_int_type(ctx)) + } +} + +impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel { + type Value = FixedInt<'ctx, T>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + T::get_int_type(ctx).as_basic_type_enum() + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + let value = review_int_llvm_value(value, T::get_int_type(ctx)).unwrap(); + FixedInt { int: self.0, value } + } +} + +impl<'ctx, T: IsFixedInt> FixedIntModel { + pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> { + FixedInt { int: self.0, value: T::get_int_type(ctx).const_int(value, false) } + } +} + impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> { pub fn to_int(self) -> Int<'ctx> { Int(self.value) @@ -111,37 +194,7 @@ impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> { } } -// Default instance is to enable `FieldBuilder::add_field_auto` -pub trait IsFixedInt: Clone + Copy + Default { - fn get_int_type(ctx: &Context) -> IntType<'_>; - fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type -} - -impl<'ctx, T: IsFixedInt> ModelValue<'ctx> for FixedInt<'ctx, T> { - fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { - self.value.as_basic_value_enum() - } -} - -impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel { - type Value = FixedInt<'ctx, T>; - - fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { - T::get_int_type(ctx).as_basic_type_enum() - } - - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { - let value = value.into_int_value(); - assert_eq!(value.get_type().get_bit_width(), T::get_bit_width()); - FixedInt { int: self.0, value } - } -} - -impl<'ctx, T: IsFixedInt> FixedIntModel { - pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> { - FixedInt { int: self.0, value: T::get_int_type(ctx).const_int(value, false) } - } -} +// Some pre-defined fixed ints #[derive(Debug, Clone, Copy, Default)] pub struct Bool; diff --git a/nac3core/src/codegen/model/pointer.rs b/nac3core/src/codegen/model/pointer.rs index 969aa124..add6bbf5 100644 --- a/nac3core/src/codegen/model/pointer.rs +++ b/nac3core/src/codegen/model/pointer.rs @@ -1,6 +1,6 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum}, values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue}, AddressSpace, }; @@ -31,7 +31,7 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value { let val = ctx.builder.build_load(self.value, name).unwrap(); - self.element.review(val.as_any_value_enum()) + self.element.review(ctx.ctx, val.as_any_value_enum()) } pub fn to_opaque(self) -> OpaquePointer<'ctx> { @@ -59,6 +59,25 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { } } +impl<'ctx, E: Model<'ctx>> CanCheckLLVMType<'ctx> for PointerModel { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + // Check if scrutinee is even a PointerValue + let AnyTypeEnum::PointerType(scrutinee) = scrutinee else { + return Err(format!("Expecting a pointer value, but got {scrutinee:?}")); + }; + + // Check the type of what the pointer is pointing at + // TODO: This will be deprecated by inkwell > llvm14 because `get_element_type()` will be gone + self.0.check_llvm_type(ctx, scrutinee.get_element_type())?; // TODO: Include backtrace? + + Ok(()) + } +} + impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel { type Value = Pointer<'ctx, E>; @@ -66,7 +85,9 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel { self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() } - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + self.check_llvm_type(ctx, value.get_type()).unwrap(); + // TODO: Check get_element_type()? for LLVM 14 at least... Pointer { element: self.0, value: value.into_pointer_value() } } @@ -85,6 +106,21 @@ impl<'ctx> ModelValue<'ctx> for OpaquePointer<'ctx> { } } +impl<'ctx> CanCheckLLVMType<'ctx> for OpaquePointerModel { + fn check_llvm_type( + &self, + _ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + // OpaquePointerModel only cares that it is a pointer, + // but not what the pointer is pointing at + match scrutinee { + AnyTypeEnum::PointerType(_) => Ok(()), + _ => Err(format!("Expecting a pointer type, but got {scrutinee:?}")), + } + } +} + impl<'ctx> Model<'ctx> for OpaquePointerModel { type Value = OpaquePointer<'ctx>; @@ -92,11 +128,11 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel { ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum() } - fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { - let ptr = value.into_pointer_value(); - // TODO: remove this check once LLVM pointers do not have `get_element_type()` - assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8); - OpaquePointer(ptr) + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + // Check if value is even of a pointer type + self.check_llvm_type(ctx, value.get_type()).unwrap(); + + OpaquePointer(value.into_pointer_value()) } } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index ac965109..759a577c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -13,7 +13,7 @@ use crate::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{iter_type_vars, FunSignature, Type, TypeEnum}, }, }; use inkwell::{ @@ -202,6 +202,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, target: &Expr>, value: ValueEnum<'ctx>, + value_ty: Type, ) -> Result<(), String> { /* To handle assignment statements `target = value`, with @@ -213,8 +214,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( - Case 3. Indexed ndarray assignment `ndarray.__setitem__` - e.g., `my_ndarray[::-1, :] = 3`, `my_ndarray[:, 3::-1] = their_ndarray[10::2]` - NOTE: Technically speaking, if `target` is sliced in such as way that it is referencing a - single element/scalar, we *could* implement gen_store_target for this special case; - but it is much, *much* simpler to generalize all indexed ndarray assignment without + single element/scalar, we *could* implement gen_store_target for this special case + (to point to the raw address of that scalar in the ndarray); but it is much, + *much* simpler to generalize all indexed ndarray assignment without special handling on that edgecase. - Otherwise, use `gen_store_target` */ @@ -230,11 +232,13 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( }; for (i, elt) in elts.iter().enumerate() { + let elem_ty = elt.custom.unwrap(); + let v = ctx .builder .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") .unwrap(); - generator.gen_assign(ctx, elt, v.into())?; + generator.gen_assign(ctx, elt, v.into(), elem_ty)?; } return Ok(()); // Terminate @@ -311,11 +315,26 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ctx, generator, target.custom.unwrap(), - ); + )?; - // let value = value.to_basic_value_enum(ctx, generator, value); - - todo!(); + match &*ctx.unifier.get_ty(value_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + // `value` is an `ndarray[dtype, ndims]` + todo!() + } + _ => { + // TODO: Inferencer's assignment forces `target` and `value` to have the same type + // NOTE: gen_assign() has already been extended, I will keep it in place + // in participation for when this is extended to be no longer the case. + todo!("support scalar assignment") + // panic!( + // "Unsupported ndarray assignment value: {}", + // ctx.unifier.stringify(value_ty) + // ); + } + } return Ok(()); // Terminate } @@ -325,7 +344,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( } } - // None of the cases match. We should actually use `gen_store_target`. + // The assignment expression matches none of the special cases. + // We should actually use `gen_store_target`. let name = if let ExprKind::Name { id, .. } = &target.node { format!("{id}.addr") } else { @@ -369,9 +389,6 @@ pub fn gen_for( let orelse_bb = if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; - // Whether the iterable is a range() expression - let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); - // The BB containing the increment expression let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); // The BB containing the loop condition check @@ -385,108 +402,136 @@ pub fn gen_for( } else { return Ok(()); }; - if is_iterable_range_expr { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); - // Internal variable for loop; Cannot be assigned - let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; - // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed - let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? - else { - unreachable!() - }; - let (start, stop, step) = destructure_range(ctx, iter_val); - - ctx.builder.build_store(i, start).unwrap(); - - // Check "If step is zero, ValueError is raised." - let rangenez = - ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap(); - ctx.make_assert( - generator, - rangenez, - "ValueError", - "range() arg 3 must not be zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + // The implementation of the for loop logic depends on + // the typechecker type of `iter`. + let iter_ty = iter.custom.unwrap(); + match &*ctx.unifier.get_ty(iter_ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { + // `iter` is a `List[T]`, and `T` is the element type + + // Get the `T` out of `List[T]` - it is defined to be the 1st param. + let list_elem_ty = iter_type_vars(params).nth(0).unwrap().ty; + + // Implementation + let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; + ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap(); + let len = ctx + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero, int32.const_int(1, false)], + Some("len"), + ) + .into_int_value(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + ctx.builder.position_at_end(cond_bb); - ctx.builder - .build_conditional_branch( - gen_in_range_check( - ctx, - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - stop, - step, - ), - body_bb, - orelse_bb, + let index = ctx + .builder + .build_load(index_addr, "for.index") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); + ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); + + ctx.builder.position_at_end(incr_bb); + let index = + ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); + ctx.builder.build_store(index_addr, inc).unwrap(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + + ctx.builder.position_at_end(body_bb); + let arr_ptr = ctx + .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) + .into_pointer_value(); + let index = ctx + .builder + .build_load(index_addr, "for.index") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); + generator.gen_assign(ctx, target, val.into(), list_elem_ty)?; + generator.gen_block(ctx, body.iter())?; + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => + { + // `iter` is a `range(start, stop, step)`, and `int32` is the element type + + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + // Internal variable for loop; Cannot be assigned + let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; + // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed + let Some(target_i) = + generator.gen_store_target(ctx, target, Some("for.target.addr"))? + else { + unreachable!() + }; + let (start, stop, step) = destructure_range(ctx, iter_val); + + ctx.builder.build_store(i, start).unwrap(); + + // Check "If step is zero, ValueError is raised." + let rangenez = ctx + .builder + .build_int_compare(IntPredicate::NE, step, int32.const_zero(), "") + .unwrap(); + ctx.make_assert( + generator, + rangenez, + "ValueError", + "range() arg 3 must not be zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + + { + ctx.builder.position_at_end(cond_bb); + ctx.builder + .build_conditional_branch( + gen_in_range_check( + ctx, + ctx.builder + .build_load(i, "") + .map(BasicValueEnum::into_int_value) + .unwrap(), + stop, + step, + ), + body_bb, + orelse_bb, + ) + .unwrap(); + } + + ctx.builder.position_at_end(incr_bb); + let next_i = ctx + .builder + .build_int_add( + ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), + step, + "inc", ) .unwrap(); + ctx.builder.build_store(i, next_i).unwrap(); + ctx.builder.build_unconditional_branch(cond_bb).unwrap(); + + ctx.builder.position_at_end(body_bb); + ctx.builder + .build_store( + target_i, + ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), + ) + .unwrap(); + generator.gen_block(ctx, body.iter())?; + } + _ => { + panic!("unsupported iterator type in for loop: {}", ctx.unifier.stringify(iter_ty)) } - - ctx.builder.position_at_end(incr_bb); - let next_i = ctx - .builder - .build_int_add( - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - step, - "inc", - ) - .unwrap(); - ctx.builder.build_store(i, next_i).unwrap(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - - ctx.builder.position_at_end(body_bb); - ctx.builder - .build_store( - target_i, - ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), - ) - .unwrap(); - generator.gen_block(ctx, body.iter())?; - } else { - let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; - ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap(); - let len = ctx - .build_gep_and_load( - iter_val.into_pointer_value(), - &[zero, int32.const_int(1, false)], - Some("len"), - ) - .into_int_value(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - - ctx.builder.position_at_end(cond_bb); - let index = ctx - .builder - .build_load(index_addr, "for.index") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); - ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); - - ctx.builder.position_at_end(incr_bb); - let index = - ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); - ctx.builder.build_store(index_addr, inc).unwrap(); - ctx.builder.build_unconditional_branch(cond_bb).unwrap(); - - ctx.builder.position_at_end(body_bb); - let arr_ptr = ctx - .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) - .into_pointer_value(); - let index = ctx - .builder - .build_load(index_addr, "for.index") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); - generator.gen_assign(ctx, target, val.into())?; - generator.gen_block(ctx, body.iter())?; } for (k, (_, _, counter)) in &var_assignment { @@ -1629,14 +1674,18 @@ pub fn gen_stmt( } StmtKind::AnnAssign { target, value, .. } => { if let Some(value) = value { + let value_ty = value.custom.unwrap(); let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; - generator.gen_assign(ctx, target, value)?; + generator.gen_assign(ctx, target, value, value_ty)?; } } StmtKind::Assign { targets, value, .. } => { + // TODO: Is the implementation wrong? It looks very strange. + let value_ty = value.custom.unwrap(); let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; + for target in targets { - generator.gen_assign(ctx, target, value.clone())?; + generator.gen_assign(ctx, target, value.clone(), value_ty)?; } } StmtKind::Continue { .. } => { @@ -1650,6 +1699,7 @@ pub fn gen_stmt( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { + let value_ty = value.custom.unwrap(); let value = gen_binop_expr( generator, ctx, @@ -1658,7 +1708,7 @@ pub fn gen_stmt( value, stmt.location, )?; - generator.gen_assign(ctx, target, value.unwrap())?; + generator.gen_assign(ctx, target, value.unwrap(), value_ty)?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ef4038f7..c4aabc0f 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> { let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet })); let ndarray_ptr = - ndarray_ptr_model.review(arg.as_any_value_enum()); + ndarray_ptr_model.review(ctx.ctx, arg.as_any_value_enum()); // Calculate len // NOTE: Unsized object is asserted in IRRT