From 529fa67855e8f44373d90d7e7b5c7a08b9745f56 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 14:01:10 +0800 Subject: [PATCH 01/13] [core] codegen: Add bool_to_int_type to replace bool_to_{i1,i8} Unifies the implementation for both functions. --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/generator.rs | 20 +++++++++--- nac3core/src/codegen/mod.rs | 52 +++++++++++-------------------- 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 20d296e4..f4e03d04 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2001,7 +2001,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ).into_int_value(); let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(result, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { result } diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 620ede0e..42c7c71b 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -7,7 +7,7 @@ use inkwell::{ use nac3parser::ast::{Expr, Stmt, StrRef}; -use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext}; +use super::{bool_to_int_type, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext}; use crate::{ symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, @@ -248,22 +248,32 @@ pub trait CodeGenerator { gen_block(self, ctx, stmts) } - /// See [`bool_to_i1`]. + /// Converts the value of a boolean-like value `bool_value` into an `i1`. fn bool_to_i1<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { - bool_to_i1(&ctx.builder, bool_value) + self.bool_to_int_type(ctx, bool_value, ctx.ctx.bool_type()) } - /// See [`bool_to_i8`]. + /// Converts the value of a boolean-like value `bool_value` into an `i8`. fn bool_to_i8<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { - bool_to_i8(&ctx.builder, ctx.ctx, bool_value) + self.bool_to_int_type(ctx, bool_value, ctx.ctx.i8_type()) + } + + /// See [`bool_to_int_type`]. + fn bool_to_int_type<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + bool_value: IntValue<'ctx>, + ty: IntType<'ctx>, + ) -> IntValue<'ctx> { + bool_to_int_type(&ctx.builder, bool_value, ty) } } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index a188d1c3..f1b9cfb9 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -933,7 +933,7 @@ pub fn gen_func_impl< let param_val = param.into_int_value(); if expected_ty.get_bit_width() == 8 && param_val.get_type().get_bit_width() == 1 { - bool_to_i8(&builder, context, param_val) + bool_to_int_type(&builder, param_val, context.i8_type()) } else { param_val } @@ -1103,43 +1103,29 @@ pub fn gen_func<'ctx, G: CodeGenerator>( }) } -/// Converts the value of a boolean-like value `bool_value` into an `i1`. -fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntValue<'ctx> { - if bool_value.get_type().get_bit_width() == 1 { - bool_value - } else { - builder - .build_int_compare( - IntPredicate::NE, - bool_value, - bool_value.get_type().const_zero(), - "tobool", - ) - .unwrap() - } -} - -/// Converts the value of a boolean-like value `bool_value` into an `i8`. -fn bool_to_i8<'ctx>( +/// Converts the value of a boolean-like value `value` into an arbitrary [`IntType`]. +/// +/// This has the same semantics as `(ty)(value != 0)` in C. +/// +/// The returned value is guaranteed to either be `0` or `1`, except for `ty == i1` where only the +/// least-significant bit would be guaranteed to be `0` or `1`. +fn bool_to_int_type<'ctx>( builder: &Builder<'ctx>, - ctx: &'ctx Context, - bool_value: IntValue<'ctx>, + value: IntValue<'ctx>, + ty: IntType<'ctx>, ) -> IntValue<'ctx> { - let value_bits = bool_value.get_type().get_bit_width(); - match value_bits { - 8 => bool_value, - 1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool").unwrap(), - _ => bool_to_i8( + // i1 -> i1 : %value ; no-op + // i1 -> i : zext i1 %value to i ; guaranteed to be 0 or 1 - see docs + // i -> i: zext i1 (icmp eq i %value, 0) to i ; same as i -> i1 -> i + match (value.get_type().get_bit_width(), ty.get_bit_width()) { + (1, 1) => value, + (1, _) => builder.build_int_z_extend(value, ty, "frombool").unwrap(), + _ => bool_to_int_type( builder, - ctx, builder - .build_int_compare( - IntPredicate::NE, - bool_value, - bool_value.get_type().const_zero(), - "", - ) + .build_int_compare(IntPredicate::NE, value, value.get_type().const_zero(), "tobool") .unwrap(), + ty, ), } } From 0d8cb909dda0f3082e3279390e7d6ac27756a0fa Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 10:38:19 +0800 Subject: [PATCH 02/13] [core] codegen/expr: Fix and use gen_unaryop_expr for boolean not ops While refactoring, I ran into the issue where `!true == true`, which was caused by the same upper 7-bit of booleans being undefined issue that was encountered before. It turns out the implementation in `gen_unaryop_expr` is also inadequate, as `(~v & (i1) 0x1)`` will still leave upper 7 bits undefined (for whatever reason). This commit fixes this issue once and for all by using a combination of `icmp` + `zext` to ensure that the resulting value must be `0 | 1`, and refactor to use that whenever we need to invert boolean values. --- nac3core/src/codegen/expr.rs | 39 ++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f4e03d04..53aa5f14 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1704,11 +1704,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(if ty == ctx.primitives.bool { let val = val.into_int_value(); if op == ast::Unaryop::Not { - let not = ctx.builder.build_not(val, "not").unwrap(); - let not_bool = - ctx.builder.build_and(not, not.get_type().const_int(1, false), "").unwrap(); + let not = ctx + .builder + .build_int_compare(IntPredicate::EQ, val, val.get_type().const_zero(), "not") + .unwrap(); - not_bool.into() + generator.bool_to_int_type(ctx, not, val.get_type()).into() } else { let llvm_i32 = ctx.ctx.i32_type(); @@ -2001,7 +2002,18 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ).into_int_value(); let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(result, "").unwrap() + gen_unaryop_expr_with_values( + generator, + ctx, + Unaryop::Not, + (&Some(ctx.primitives.bool), result.into()), + ) + .transpose() + .unwrap() + .and_then(|res| { + res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + })? + .into_int_value() } else { result } @@ -2248,8 +2260,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .unwrap() .and_then(|v| { v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - }) - .map(BasicValueEnum::into_int_value)?; + })? + .into_int_value(); Ok(ctx.builder.build_not( generator.bool_to_i1(ctx, cmp), @@ -2285,7 +2297,18 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( // Invert the final value if __ne__ if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + gen_unaryop_expr_with_values( + generator, + ctx, + Unaryop::Not, + (&Some(ctx.primitives.bool), cmp_phi.into()) + ) + .transpose() + .unwrap() + .and_then(|res| { + res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + })? + .into_int_value() } else { cmp_phi } From c37c7e8975ea098b1b2ddcc1f2cab6138b958533 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 10:39:21 +0800 Subject: [PATCH 03/13] [core] codegen/expr: Simplify `gen_*_expr_with_values` return value These functions always return `BasicValueEnum` because they operate on `BasicValueEnum`s, and they also always return a value. --- nac3core/src/codegen/expr.rs | 88 ++++++------------- nac3core/src/codegen/values/ndarray/matmul.rs | 8 +- 2 files changed, 30 insertions(+), 66 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 53aa5f14..986ed992 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op: Binop, right: (&Option, BasicValueEnum<'ctx>), loc: Location, -) -> Result>, String> { +) -> Result, String> { let (left_ty, left_val) = left; let (right_ty, right_val) = right; @@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true)) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false)) } else if [Operator::LShift, Operator::RShift].contains(&op.base) { let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed)) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into())) + Ok(ctx.gen_float_ops(op.base, left_val, right_val)) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert_eq!(op.base, Operator::Pow); @@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( right_val.into_int_value(), Some("f_pow_i"), ); - Ok(Some(res.into())) + Ok(res.into()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } _ => todo!("Operator not supported"), @@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let result = left .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) .split_unsized(generator, ctx); - Ok(Some(result.to_basic_value_enum().into())) + Ok(result.to_basic_value_enum()) } else { // For other operations, they are all elementwise operations. @@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ty2_dtype), right_value), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, common_dtype)?; + )?; Ok(result) }) .unwrap(); - Ok(Some(result.as_abi_value(ctx).into())) + Ok(result.as_abi_value(ctx).into()) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( (&signature, fun_id), vec![(None, right_val.into())], ) - .map(|f| f.map(Into::into)) + .map(Option::unwrap) + .map(BasicValueEnum::into) } } @@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( (&right.custom, right_val), loc, ) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a unary operator expression using the [`Type`] and @@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, op: ast::Unaryop, operand: (&Option, BasicValueEnum<'ctx>), -) -> Result>, String> { +) -> Result, String> { let (ty, val) = operand; let ty = ctx.unifier.get_representative(ty.unwrap()); - Ok(Some(if ty == ctx.primitives.bool { + Ok(if ty == ctx.primitives.bool { let val = val.into_int_value(); if op == ast::Unaryop::Not { let not = ctx @@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(), ), )? - .unwrap() } } else if [ ctx.primitives.int32, @@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx, NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, |generator, ctx, scalar| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? - .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) - .unwrap() + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar)) }, )?; mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() - })) + }) } /// Generates LLVM IR for a unary operator expression. @@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( }; gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a comparison operator expression using the [`Type`] and @@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( left: (Option, BasicValueEnum<'ctx>), ops: &[ast::Cmpop], comparators: &[(Option, BasicValueEnum<'ctx>)], -) -> Result>, String> { +) -> Result, String> { debug_assert_eq!(comparators.len(), ops.len()); if comparators.len() == 1 { @@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( (Some(left_ty_dtype), left_scalar), &[op], &[(Some(right_ty_dtype), right_scalar)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, )?; Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) }, )?; - return Ok(Some(result_ndarray.as_abi_value(ctx).into())); + return Ok(result_ndarray.as_abi_value(ctx).into()); } } @@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx, Unaryop::Not, (&Some(ctx.primitives.bool), result.into()), - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + )?.into_int_value() } else { result } @@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[Cmpop::Eq], &[(Some(right_elem_ty), right)], )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool) - .unwrap() .into_int_value(); gen_if_callback( @@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( Unaryop::Not, (&Some(ctx.primitives.bool), acc.into()), )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool)? .into_int_value() } else { acc @@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[op], &[(Some(right_ty), right_elem)], ) - .transpose() - .unwrap() - .and_then(|v| { - v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value(); + .map(BasicValueEnum::into_int_value)?; Ok(ctx.builder.build_not( generator.bool_to_i1(ctx, cmp), @@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, Unaryop::Not, - (&Some(ctx.primitives.bool), cmp_phi.into()) - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + (&Some(ctx.primitives.bool), cmp_phi.into()), + )?.into_int_value() } else { cmp_phi } @@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) - })?; + })?.unwrap(); - Ok(Some(match cmp_val { - Some(v) => v.into(), - None => return Ok(None), - })) + Ok(cmp_val.into()) } /// Generates LLVM IR for a comparison operator expression. @@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ops, comparator_vals.as_slice(), ) + .map(|res| Some(res.into())) } /// See [`CodeGenerator::gen_expr`]. diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index f12d36c1..cc8d059a 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Mult), (&Some(rhs_dtype), b_kj), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; // dst_[...]ij += x let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); @@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Add), (&Some(dst_dtype), x), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); Ok(()) From a078481cd2c7e1b51aa991d4fd5c00b6f48a6818 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 13:27:45 +0800 Subject: [PATCH 04/13] [meta] Minor simplification for PrimStore extraction --- nac3artiq/src/codegen.rs | 12 +++++------- nac3core/src/toplevel/builtins.rs | 4 +--- nac3core/src/toplevel/composer.rs | 3 +-- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index fb6992b9..d086420c 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -41,7 +41,10 @@ use nac3core::{ numpy::unpack_ndarray_var_tys, DefinitionId, GenCall, }, - typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, + }, }; /// The parallelism mode within a block. @@ -389,12 +392,7 @@ fn gen_rpc_tag( ) -> Result<(), String> { use nac3core::typecheck::typedef::TypeEnum::*; - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let bool = ctx.primitives.bool; - let str = ctx.primitives.str; - let none = ctx.primitives.none; + let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives; if ctx.unifier.unioned(ty, int32) { buffer.push(b'i'); diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index eff614e5..c9b5d222 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -36,9 +36,7 @@ pub fn get_exn_constructor( unifier: &mut Unifier, primitives: &PrimitiveStore, ) -> (TopLevelDef, TopLevelDef, Type, Type) { - let int32 = primitives.int32; - let int64 = primitives.int64; - let string = primitives.str; + let PrimitiveStore { int32, int64, str: string, .. } = *primitives; let exception_fields = make_exception_fields(int32, int64, string); let exn_cons_args = vec![ FuncArg { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a6a0ce76..50d6dd2e 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1521,8 +1521,7 @@ impl TopLevelComposer { .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { // create constructor for these classes - let string = primitives_ty.str; - let int64 = primitives_ty.int64; + let PrimitiveStore { str: string, int64, .. } = *primitives_ty; let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { From 2df22e29f738ab214893c6d2f0e083618d9a232d Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 11:18:08 +0800 Subject: [PATCH 05/13] [core] codegen: Simplify TupleType::construct --- nac3core/src/codegen/types/tuple.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index ea66feb4..3facf5ea 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -115,15 +115,8 @@ impl<'ctx> TupleType<'ctx> { /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. #[must_use] - pub fn construct( - &self, - ctx: &CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - self.map_struct_value( - Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), - name, - ) + pub fn construct(&self, name: Option<&'ctx str>) -> >::Value { + self.map_struct_value(self.as_abi_type().const_zero(), name) } /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of @@ -143,7 +136,7 @@ impl<'ctx> TupleType<'ctx> { .enumerate() .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); - let mut value = self.construct(ctx, name); + let mut value = self.construct(name); for (i, val) in values.into_iter().enumerate() { value.store_element(ctx, i as u32, val); } From 69542c38a2bbac50401c111c29c907d04da69641 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 11:07:01 +0800 Subject: [PATCH 06/13] [core] codegen: Rename TupleValue::{store,load} -> {insert,extract} Better matches the underlying operation. --- nac3core/src/codegen/types/tuple.rs | 2 +- nac3core/src/codegen/values/ndarray/shape.rs | 2 +- nac3core/src/codegen/values/tuple.rs | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 3facf5ea..90abeb34 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -138,7 +138,7 @@ impl<'ctx> TupleType<'ctx> { let mut value = self.construct(name); for (i, val) in values.into_iter().enumerate() { - value.store_element(ctx, i as u32, val); + value.insert_element(ctx, i as u32, val); } value diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index b3331b6f..69e8b50b 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -106,7 +106,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( for i in 0..input_seq.get_type().num_elements() { // Get the i-th element off of the tuple and load it into `result`. - let int = input_seq.load_element(ctx, i).into_int_value(); + let int = input_seq.extract_element(ctx, i).into_int_value(); let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); unsafe { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 320e2190..1f124c8c 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -45,7 +45,7 @@ impl<'ctx> TupleValue<'ctx> { } /// Stores a value into the tuple element at the given `index`. - pub fn store_element( + pub fn insert_element( &mut self, ctx: &CodeGenContext<'ctx, '_>, index: u32, @@ -63,7 +63,11 @@ impl<'ctx> TupleValue<'ctx> { } /// Loads a value from the tuple element at the given `index`. - pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + pub fn extract_element( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + ) -> BasicValueEnum<'ctx> { ctx.builder .build_extract_value( self.value, From 67f42185de653fc7a7f315dc6ae1c007bfdbda16 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 10:37:55 +0800 Subject: [PATCH 07/13] [core] codegen/expr: Add concrete ndims value to error message --- nac3core/src/codegen/expr.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 986ed992..cd9b87d5 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -43,7 +43,7 @@ use super::{ use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ - helper::{arraylike_flatten_element_type, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef, }, @@ -1775,10 +1775,13 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { + let ndims = extract_ndims(&ctx.unifier, ty); + codegen_unreachable!( ctx, - "ufunc {} not supported for ndarray[bool, N]", + "ufunc {} not supported for ndarray[bool, {}]", op.op_info().method_name, + ndims, ) } } else { From 0a761cb2637ef9bd97b92362ecc214884883851b Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 15:23:07 +0800 Subject: [PATCH 08/13] [core] Use more TupleType constructors --- nac3core/src/codegen/expr.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index cd9b87d5..bab3b75d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType}, + types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -180,23 +180,10 @@ impl<'ctx> CodeGenContext<'ctx, '_> { SymbolValue::Tuple(ls) => { let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); let fields = vals.iter().map(BasicValueEnum::get_type).collect_vec(); - let ty = self.ctx.struct_type(&fields, false); - let ptr = gen_var(self, ty.into(), Some("tuple")).unwrap(); - let zero = self.ctx.i32_type().const_zero(); - unsafe { - for (i, val) in vals.into_iter().enumerate() { - let p = self - .builder - .build_in_bounds_gep( - ptr, - &[zero, self.ctx.i32_type().const_int(i as u64, false)], - "elemptr", - ) - .unwrap(); - self.builder.build_store(p, val).unwrap(); - } - } - self.builder.build_load(ptr, "tup_val").unwrap() + TupleType::new(self, &fields) + .construct_from_objects(self, vals, Some("tup_val")) + .as_abi_value(self) + .into() } SymbolValue::OptionSome(v) => { let ty = match self.unifier.get_ty_immutable(ty).as_ref() { From 35e9c5b38e2bfc240f990b789bca632ed795e1d6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 15:43:48 +0800 Subject: [PATCH 09/13] [core] codegen: Add String{Type,Value} --- nac3core/src/codegen/expr.rs | 59 ++------- nac3core/src/codegen/irrt/string.rs | 26 ++-- nac3core/src/codegen/mod.rs | 16 +-- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/string.rs | 177 ++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/string.rs | 87 +++++++++++++ 7 files changed, 290 insertions(+), 79 deletions(-) create mode 100644 nac3core/src/codegen/types/string.rs create mode 100644 nac3core/src/codegen/values/string.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index bab3b75d..c398ed97 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, + types::{ndarray::NDArrayType, ListType, RangeType, StringType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -168,14 +168,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { SymbolValue::Bool(v) => self.ctx.i8_type().const_int(u64::from(*v), true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Str(v) => { - let str_ptr = self - .builder - .build_global_string_ptr(v, "const") - .map(|v| v.as_pointer_value().into()) - .unwrap(); - let size = self.get_size_type().const_int(v.len() as u64, false); - let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); - ty.const_named_struct(&[str_ptr, size.into()]).into() + StringType::new(self).construct_constant(self, v, None).as_abi_value(self).into() } SymbolValue::Tuple(ls) => { let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); @@ -308,15 +301,10 @@ impl<'ctx> CodeGenContext<'ctx, '_> { if let Some(v) = self.const_strings.get(v) { Some(*v) } else { - let str_ptr = self - .builder - .build_global_string_ptr(v, "const") - .map(|v| v.as_pointer_value().into()) - .unwrap(); - let size = self.get_size_type().const_int(v.len() as u64, false); - let ty = self.get_llvm_type(generator, self.primitives.str); - let val = - ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); + let val = StringType::new(self) + .construct_constant(self, v, None) + .as_abi_value(self) + .into(); self.const_strings.insert(v.to_string(), val); Some(val) } @@ -1950,39 +1938,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let lhs = lhs.into_struct_value(); - let rhs = rhs.into_struct_value(); + let llvm_str = StringType::new(ctx); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = ctx.get_size_type(); + let lhs = llvm_str.map_struct_value(lhs.into_struct_value(), None); + let rhs = llvm_str.map_struct_value(rhs.into_struct_value(), None); - let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); - ctx.builder.build_store(plhs, lhs).unwrap(); - let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); - ctx.builder.build_store(prhs, rhs).unwrap(); - - let lhs_ptr = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_usize.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - let lhs_len = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], - None, - ).into_int_value(); - - let rhs_ptr = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_usize.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - let rhs_len = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], - None, - ).into_int_value(); - let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); + let result = call_string_eq(ctx, lhs, rhs); if *op == Cmpop::NotEq { gen_unaryop_expr_with_values( generator, diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index e015570a..c7e4eebf 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,26 +1,15 @@ -use inkwell::{ - values::{BasicValueEnum, IntValue, PointerValue}, - AddressSpace, -}; +use inkwell::values::{BasicValueEnum, IntValue}; use super::get_usize_dependent_function_name; -use crate::codegen::{expr::infer_and_call_function, CodeGenContext}; +use crate::codegen::{expr::infer_and_call_function, values::StringValue, CodeGenContext}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx>( ctx: &CodeGenContext<'ctx, '_>, - str1_ptr: PointerValue<'ctx>, - str1_len: IntValue<'ctx>, - str2_ptr: PointerValue<'ctx>, - str2_len: IntValue<'ctx>, + str1: StringValue<'ctx>, + str2: StringValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let llvm_usize = ctx.get_size_type(); - assert_eq!(str1_ptr.get_type(), llvm_pi8); - assert_eq!(str1_len.get_type(), llvm_usize); - assert_eq!(str2_ptr.get_type(), llvm_pi8); - assert_eq!(str2_len.get_type(), llvm_usize); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); @@ -28,7 +17,12 @@ pub fn call_string_eq<'ctx>( ctx, &func_name, Some(llvm_i1.into()), - &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + &[ + str1.extract_ptr(ctx).into(), + str1.extract_len(ctx).into(), + str2.extract_ptr(ctx).into(), + str2.extract_len(ctx).into(), + ], Some("str_eq_call"), None, ) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f1b9cfb9..5b6fa215 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -43,7 +43,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, StringType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -786,19 +786,7 @@ pub fn gen_func_impl< (primitives.float, context.f64_type().into()), (primitives.bool, context.i8_type().into()), (primitives.str, { - let name = "str"; - match module.get_struct_type(name) { - None => { - let str_type = context.opaque_struct_type("str"); - let fields = [ - context.i8_type().ptr_type(AddressSpace::default()).into(), - generator.get_size_type(context).into(), - ]; - str_type.set_body(&fields, false); - str_type.into() - } - Some(t) => t.as_basic_type_enum(), - } + StringType::new_with_generator(generator, context).as_abi_type().into() }), (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index abeab5ba..bceb8040 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -27,11 +27,13 @@ use super::{ }; pub use list::*; pub use range::*; +pub use string::*; pub use tuple::*; mod list; pub mod ndarray; mod range; +mod string; pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/types/string.rs b/nac3core/src/codegen/types/string.rs new file mode 100644 index 00000000..eae275da --- /dev/null +++ b/nac3core/src/codegen/types/string.rs @@ -0,0 +1,177 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{GlobalValue, IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use super::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, +}; +use crate::codegen::{values::StringValue, CodeGenContext, CodeGenerator}; + +/// Proxy type for a `str` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct StringType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct StringStructFields<'ctx> { + /// Pointer to the first character of the string. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub ptr: StructField<'ctx, PointerValue<'ctx>>, + + /// Length of the string. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> StringType<'ctx> { + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(llvm_usize: IntType<'ctx>) -> StringStructFields<'ctx> { + StringStructFields::new(llvm_usize.get_context(), llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `str`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> StructType<'ctx> { + const NAME: &str = "str"; + + if let Some(t) = ctx.get_struct_type(NAME) { + t + } else { + let str_ty = ctx.opaque_struct_type(NAME); + let field_tys = Self::fields(llvm_usize).into_iter().map(|field| field.1).collect_vec(); + str_ty.set_body(&field_tys, false); + str_ty + } + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_str = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_str, llvm_usize } + } + + /// Creates an instance of [`StringType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`StringType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`StringType`] from a [`StructType`] representing a `str`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ty, llvm_usize).is_ok()); + + Self { ty, llvm_usize } + } + + /// Creates an [`StringType`] from a [`PointerType`] representing a `str`. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize) + } + + /// Returns the fields present in this [`StringType`]. + #[must_use] + pub fn get_fields(&self) -> StringStructFields<'ctx> { + Self::fields(self.llvm_usize) + } + + /// Constructs a global constant string. + #[must_use] + pub fn construct_constant( + &self, + ctx: &CodeGenContext<'ctx, '_>, + v: &str, + name: Option<&'ctx str>, + ) -> StringValue<'ctx> { + let str_ptr = ctx + .builder + .build_global_string_ptr(v, "const") + .map(GlobalValue::as_pointer_value) + .unwrap(); + let size = ctx.get_size_type().const_int(v.len() as u64, false); + self.map_struct_value( + self.as_abi_type().const_named_struct(&[str_ptr.into(), size.into()]), + name, + ) + } + + /// Converts an existing value into a [`StringValue`]. + #[must_use] + pub fn map_struct_value( + &self, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } + + /// Converts an existing value into a [`StringValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(ctx, value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for StringType<'ctx> { + type ABI = StructType<'ctx>; + type Base = StructType<'ctx>; + type Value = StringValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected structure type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + check_struct_type_matches_fields(Self::fields(llvm_usize), ty, "str", &[]) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: StringType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 90f327e0..cf125fee 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -4,12 +4,14 @@ use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; +pub use string::*; pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +mod string; pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/values/string.rs b/nac3core/src/codegen/values/string.rs new file mode 100644 index 00000000..a4c8beac --- /dev/null +++ b/nac3core/src/codegen/values/string.rs @@ -0,0 +1,87 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, +}; + +use crate::codegen::{ + types::{structure::StructField, StringType}, + values::ProxyValue, + CodeGenContext, +}; + +/// Proxy type for accessing a `str` value in LLVM. +#[derive(Copy, Clone)] +pub struct StringValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> StringValue<'ctx> { + /// Creates an [`StringValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(val, llvm_usize).is_ok()); + + Self { value: val, llvm_usize, name } + } + + /// Creates an [`StringValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let val = ctx.builder.build_load(ptr, "").map(BasicValueEnum::into_struct_value).unwrap(); + + Self::from_struct_value(val, llvm_usize, name) + } + + fn ptr_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().ptr + } + + /// Returns the pointer to the beginning of the string. + pub fn extract_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.ptr_field().extract_value(ctx, self.value) + } + + fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().len + } + + /// Returns the length of the string. + pub fn extract_len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + self.len_field().extract_value(ctx, self.value) + } +} + +impl<'ctx> ProxyValue<'ctx> for StringValue<'ctx> { + type ABI = StructValue<'ctx>; + type Base = StructValue<'ctx>; + type Type = StringType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_struct_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: StringValue<'ctx>) -> Self { + value.as_base_value() + } +} From 57552fb2f641c34e81537dfbca90aab238a535b2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 17:04:37 +0800 Subject: [PATCH 10/13] [core] codegen: Add Option{Type,Value} --- nac3core/src/codegen/expr.rs | 66 +++------ nac3core/src/codegen/mod.rs | 12 +- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/option.rs | 188 ++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/option.rs | 75 ++++++++++ 6 files changed, 296 insertions(+), 49 deletions(-) create mode 100644 nac3core/src/codegen/types/option.rs create mode 100644 nac3core/src/codegen/values/option.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c398ed97..3f48154c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType, StringType, TupleType}, + types::{ndarray::NDArrayType, ListType, OptionType, RangeType, StringType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -179,34 +179,16 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .into() } SymbolValue::OptionSome(v) => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => codegen_unreachable!(self, "must be option type"), - }; let val = self.gen_symbol_val(generator, v, ty); - let ptr = generator - .gen_var_alloc(self, val.get_type(), Some("default_opt_some")) - .unwrap(); - self.builder.build_store(ptr, val).unwrap(); - ptr.into() - } - SymbolValue::OptionNone => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => codegen_unreachable!(self, "must be option type"), - }; - let actual_ptr_type = - self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); - actual_ptr_type.const_null().into() + OptionType::from_unifier_type(generator, self, ty) + .construct_some_value(generator, self, &val, None) + .as_abi_value(self) + .into() } + SymbolValue::OptionNone => OptionType::from_unifier_type(generator, self, ty) + .construct_empty(generator, self, None) + .as_abi_value(self) + .into(), } } @@ -2333,16 +2315,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( const_val.into() } ExprKind::Name { id, .. } if id == &"none".into() => { - match ( - ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), - ctx.unifier.get_ty(ctx.primitives.option).as_ref(), - ) { - (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) - if *obj_id == *opt_id => + match &*ctx.unifier.get_ty(expr.custom.unwrap()) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => { - ctx.get_llvm_type(generator, *params.iter().next().unwrap().1) - .ptr_type(AddressSpace::default()) - .const_null() + OptionType::from_unifier_type(generator, ctx, expr.custom.unwrap()) + .construct_empty(generator, ctx, None) + .as_abi_value(ctx) .into() } _ => codegen_unreachable!(ctx, "must be option type"), @@ -2827,8 +2806,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; } ValueEnum::Dynamic(BasicValueEnum::PointerValue(ptr)) => { - let not_null = - ctx.builder.build_is_not_null(ptr, "unwrap_not_null").unwrap(); + let option = OptionType::from_pointer_type( + ptr.get_type(), + ctx.get_size_type(), + ) + .map_pointer_value(ptr, None); + let not_null = option.is_some(ctx); ctx.make_assert( generator, not_null, @@ -2837,12 +2820,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( [None, None, None], expr.location, ); - return Ok(Some( - ctx.builder - .build_load(ptr, "unwrap_some_load") - .map(Into::into) - .unwrap(), - )); + return Ok(Some(unsafe { option.load(ctx).into() })); } ValueEnum::Dynamic(_) => { codegen_unreachable!(ctx, "option must be static or ptr") diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 5b6fa215..b9c743c7 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -43,7 +43,9 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, StringType, TupleType}; +use types::{ + ndarray::NDArrayType, ListType, OptionType, ProxyType, RangeType, StringType, TupleType, +}; pub mod builtin_fns; pub mod concrete_type; @@ -538,7 +540,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( if PrimDef::contains_id(*obj_id) { return match &*unifier.get_ty_immutable(ty) { TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => { - get_llvm_type( + let element_type = get_llvm_type( ctx, module, generator, @@ -546,9 +548,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( top_level, type_cache, *params.iter().next().unwrap().1, - ) - .ptr_type(AddressSpace::default()) - .into() + ); + + OptionType::new_with_generator(generator, ctx, &element_type).as_abi_type().into() } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index bceb8040..cbab600b 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -26,12 +26,14 @@ use super::{ {CodeGenContext, CodeGenerator}, }; pub use list::*; +pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; mod list; pub mod ndarray; +mod option; mod range; mod string; pub mod structure; diff --git a/nac3core/src/codegen/types/option.rs b/nac3core/src/codegen/types/option.rs new file mode 100644 index 00000000..6347e5ab --- /dev/null +++ b/nac3core/src/codegen/types/option.rs @@ -0,0 +1,188 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType}, + values::{BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use super::ProxyType; +use crate::{ + codegen::{values::OptionValue, CodeGenContext, CodeGenerator}, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, +}; + +/// Proxy type for an `Option` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct OptionType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> OptionType<'ctx> { + /// Creates an LLVM type corresponding to the expected structure of an `Option`. + #[must_use] + fn llvm_type(element_type: &impl BasicType<'ctx>) -> PointerType<'ctx> { + element_type.ptr_type(AddressSpace::default()) + } + + fn new_impl(element_type: &impl BasicType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_option = Self::llvm_type(element_type); + + Self { ty: llvm_option, llvm_usize } + } + + /// Creates an instance of [`OptionType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(element_type, ctx.get_size_type()) + } + + /// Creates an instance of [`OptionType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + element_type: &impl BasicType<'ctx>, + ) -> Self { + Self::new_impl(element_type, generator.get_size_type(ctx)) + } + + /// Creates an [`OptionType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `element_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `option` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = ctx.get_size_type(); + let llvm_elem_type = ctx.get_llvm_type(generator, elem_type); + + Self::new_impl(&llvm_elem_type, llvm_usize) + } + + /// Creates an [`OptionType`] from a [`PointerType`]. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Returns the element type of this `Option` type. + #[must_use] + pub fn element_type(&self) -> BasicTypeEnum<'ctx> { + BasicTypeEnum::try_from(self.ty.get_element_type()).unwrap() + } + + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will be `Some(v)` if [`value` contains a value][Option::is_some], + /// otherwise `none` will be returned. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: Option>, + name: Option<&'ctx str>, + ) -> >::Value { + let ptr = if let Some(v) = value { + let pvar = self.raw_alloca_var(generator, ctx, name); + ctx.builder.build_store(pvar, v).unwrap(); + pvar + } else { + self.ty.const_null() + }; + + self.map_pointer_value(ptr, name) + } + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will always be `none`. + #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.construct(generator, ctx, None, name) + } + + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will be set to `Some(value)`. + #[must_use] + pub fn construct_some_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: &impl BasicValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + self.construct(generator, ctx, Some(value.as_basic_value_enum()), name) + } + + /// Converts an existing value into a [`OptionValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for OptionType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = OptionValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + BasicTypeEnum::try_from(ty.get_element_type()) + .map_err(|()| format!("Expected `ty` to be a BasicTypeEnum, got {ty}"))?; + + Ok(()) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.element_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: OptionType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index cf125fee..7a43ba41 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -3,6 +3,7 @@ use inkwell::{types::IntType, values::BasicValue}; use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; +pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; @@ -10,6 +11,7 @@ pub use tuple::*; mod array; mod list; pub mod ndarray; +mod option; mod range; mod string; pub mod structure; diff --git a/nac3core/src/codegen/values/option.rs b/nac3core/src/codegen/values/option.rs new file mode 100644 index 00000000..7fca60f8 --- /dev/null +++ b/nac3core/src/codegen/values/option.rs @@ -0,0 +1,75 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::OptionType, CodeGenContext}; + +/// Proxy type for accessing a `Option` value in LLVM. +#[derive(Copy, Clone)] +pub struct OptionValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> OptionValue<'ctx> { + /// Creates an [`OptionValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + /// Returns an `i1` indicating if this `Option` instance does not hold a value. + #[must_use] + pub fn is_none(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + ctx.builder.build_is_null(self.value, "").unwrap() + } + + /// Returns an `i1` indicating if this `Option` instance contains a value. + #[must_use] + pub fn is_some(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + ctx.builder.build_is_not_null(self.value, "").unwrap() + } + + /// Loads the value present in this `Option` instance. + /// + /// # Safety + /// + /// The caller must ensure that this `option` value [contains a value][Self::is_some]. + #[must_use] + pub unsafe fn load(&self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { + ctx.builder.build_load(self.value, "").unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for OptionValue<'ctx> { + type ABI = PointerValue<'ctx>; + type Base = PointerValue<'ctx>; + type Type = OptionType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: OptionValue<'ctx>) -> Self { + value.as_base_value() + } +} From 064aa0411f678bd40fe8acdeffd3cad71894e1b5 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 4 Feb 2025 14:43:33 +0800 Subject: [PATCH 11/13] [core] codegen: Add Exception{Type,Value} --- nac3core/src/codegen/expr.rs | 63 +++--- nac3core/src/codegen/stmt.rs | 42 +--- nac3core/src/codegen/types/exception.rs | 257 +++++++++++++++++++++++ nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/values/exception.rs | 188 +++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + 6 files changed, 488 insertions(+), 66 deletions(-) create mode 100644 nac3core/src/codegen/types/exception.rs create mode 100644 nac3core/src/codegen/values/exception.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3f48154c..b3cb3691 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,9 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, OptionType, RangeType, StringType, TupleType}, + types::{ + ndarray::NDArrayType, ExceptionType, ListType, OptionType, RangeType, StringType, TupleType, + }, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -576,42 +578,35 @@ impl<'ctx> CodeGenContext<'ctx, '_> { params: [Option>; 3], loc: Location, ) { + let llvm_i32 = self.ctx.i32_type(); + let llvm_i64 = self.ctx.i64_type(); + let llvm_exn = ExceptionType::get_instance(generator, self); + let zelf = if let Some(exception_val) = self.exception_val { - exception_val + llvm_exn.map_pointer_value(exception_val, Some("exn")) } else { - let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); - let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); - let zelf = generator.gen_var_alloc(self, zelf_ty, Some("exn")).unwrap(); - *self.exception_val.insert(zelf) + let zelf = llvm_exn.alloca_var(generator, self, Some("exn")); + self.exception_val = Some(zelf.as_abi_value(self)); + zelf }; - let int32 = self.ctx.i32_type(); - let zero = int32.const_zero(); - unsafe { - let id_ptr = self.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); - let id = self.resolver.get_string_id(name); - self.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); - let ptr = self - .builder - .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") - .unwrap(); - self.builder.build_store(ptr, msg).unwrap(); - let i64_zero = self.ctx.i64_type().const_zero(); - for (i, attr_ind) in [6, 7, 8].iter().enumerate() { - let ptr = self - .builder - .build_in_bounds_gep( - zelf, - &[zero, int32.const_int(*attr_ind, false)], - "exn.param", - ) - .unwrap(); - let val = params[i].map_or(i64_zero, |v| { - self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap() - }); - self.builder.build_store(ptr, val).unwrap(); - } - } - gen_raise(generator, self, Some(&zelf.into()), loc); + + let id = self.resolver.get_string_id(name); + zelf.store_name(self, llvm_i32.const_int(id as u64, false)); + zelf.store_message(self, msg.into_struct_value()); + zelf.store_params( + self, + params + .iter() + .map(|p| { + p.map_or(llvm_i64.const_zero(), |v| { + self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap() + }) + }) + .collect_array() + .as_ref() + .unwrap(), + ); + gen_raise(generator, self, Some(&zelf), loc); } pub fn make_assert( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 0c1b931a..35ffeead 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -17,10 +17,10 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - types::{ndarray::NDArrayType, RangeType}, + types::{ndarray::NDArrayType, ExceptionType, RangeType}, values::{ ndarray::{RustNDIndex, ScalarOrNDArray}, - ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, + ArrayLikeIndexer, ArraySliceValue, ExceptionValue, ListValue, ProxyValue, }, CodeGenContext, CodeGenerator, }; @@ -1337,43 +1337,19 @@ pub fn exn_constructor<'ctx>( pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - exception: Option<&BasicValueEnum<'ctx>>, + exception: Option<&ExceptionValue<'ctx>>, loc: Location, ) { if let Some(exception) = exception { - unsafe { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let exception = exception.into_pointer_value(); - let file_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr") - .unwrap(); - let filename = ctx.gen_string(generator, loc.file.0); - ctx.builder.build_store(file_ptr, filename).unwrap(); - let row_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr") - .unwrap(); - ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap(); - let col_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr") - .unwrap(); - ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap(); + exception.store_location(generator, ctx, loc); - let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); - let name_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") - .unwrap(); - ctx.builder.build_store(name_ptr, fun_name).unwrap(); - } + let current_fun = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); + let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); + exception.store_func(ctx, fun_name); let raise = get_builtins(generator, ctx, "__nac3_raise"); let exception = *exception; - ctx.build_call_or_invoke(raise, &[exception], "raise"); + ctx.build_call_or_invoke(raise, &[exception.as_abi_value(ctx).into()], "raise"); } else { let resume = get_builtins(generator, ctx, "__nac3_resume"); ctx.build_call_or_invoke(resume, &[], "resume"); @@ -1860,6 +1836,8 @@ pub fn gen_stmt( } else { return Ok(()); }; + let exc = ExceptionType::get_instance(generator, ctx) + .map_pointer_value(exc.into_pointer_value(), None); gen_raise(generator, ctx, Some(&exc), stmt.location); } else { gen_raise(generator, ctx, None, stmt.location); diff --git a/nac3core/src/codegen/types/exception.rs b/nac3core/src/codegen/types/exception.rs new file mode 100644 index 00000000..0a8ec05d --- /dev/null +++ b/nac3core/src/codegen/types/exception.rs @@ -0,0 +1,257 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use super::{ + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, + ProxyType, +}; +use crate::{ + codegen::{values::ExceptionValue, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Proxy type for an `Exception` in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ExceptionType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ExceptionStructFields<'ctx> { + /// The ID of the exception name. + #[value_type(i32_type())] + pub name: StructField<'ctx, IntValue<'ctx>>, + + /// The file where the exception originated from. + #[value_type(get_struct_type("str").unwrap())] + pub file: StructField<'ctx, StructValue<'ctx>>, + + /// The line number where the exception originated from. + #[value_type(i32_type())] + pub line: StructField<'ctx, IntValue<'ctx>>, + + /// The column number where the exception originated from. + #[value_type(i32_type())] + pub col: StructField<'ctx, IntValue<'ctx>>, + + /// The function name where the exception originated from. + #[value_type(get_struct_type("str").unwrap())] + pub func: StructField<'ctx, StructValue<'ctx>>, + + /// The exception message. + #[value_type(get_struct_type("str").unwrap())] + pub message: StructField<'ctx, StructValue<'ctx>>, + + #[value_type(i64_type())] + pub param0: StructField<'ctx, IntValue<'ctx>>, + + #[value_type(i64_type())] + pub param1: StructField<'ctx, IntValue<'ctx>>, + + #[value_type(i64_type())] + pub param2: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ExceptionType<'ctx> { + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ExceptionStructFields<'ctx> { + ExceptionStructFields::new(ctx, llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of an `Exception`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + assert!(ctx.get_struct_type("str").is_some()); + + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_str = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_str, llvm_usize } + } + + /// Creates an instance of [`ExceptionType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ExceptionType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`ExceptionType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.exception.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an [`ExceptionType`] from a [`StructType`] representing an `Exception`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Creates an [`ExceptionType`] from a [`PointerType`] representing an `Exception`. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Returns an instance of [`ExceptionType`] by obtaining the LLVM representation of the builtin + /// `Exception` type. + #[must_use] + pub fn get_instance( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + Self::from_pointer_type( + ctx.get_llvm_type(generator, ctx.primitives.exception).into_pointer_type(), + ctx.get_size_type(), + ) + } + + /// Allocates an instance of [`ExceptionValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ExceptionValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ExceptionValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ExceptionValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ExceptionType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = ExceptionValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields(Self::fields(ctx, llvm_usize), llvm_ty, "exception", &[]) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for ExceptionType<'ctx> { + type StructFields = ExceptionStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ExceptionType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index cbab600b..1dc776b9 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -25,12 +25,14 @@ use super::{ values::{ArraySliceValue, ProxyValue}, {CodeGenContext, CodeGenerator}, }; +pub use exception::*; pub use list::*; pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; +mod exception; mod list; pub mod ndarray; mod option; diff --git a/nac3core/src/codegen/values/exception.rs b/nac3core/src/codegen/values/exception.rs new file mode 100644 index 00000000..0b1796b9 --- /dev/null +++ b/nac3core/src/codegen/values/exception.rs @@ -0,0 +1,188 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue, StructValue}, +}; +use itertools::Itertools; + +use nac3parser::ast::Location; + +use super::{structure::StructProxyValue, ProxyValue, StringValue}; +use crate::codegen::{ + types::{ + structure::{StructField, StructProxyType}, + ExceptionType, + }, + CodeGenContext, CodeGenerator, +}; + +/// Proxy type for accessing an `Exception` value in LLVM. +#[derive(Copy, Clone)] +pub struct ExceptionValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ExceptionValue<'ctx> { + /// Creates an [`ExceptionValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + + /// Creates an [`ExceptionValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn name_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().name + } + + /// Stores the ID of the exception name into this instance. + pub fn store_name(&self, ctx: &CodeGenContext<'ctx, '_>, name: IntValue<'ctx>) { + debug_assert_eq!(name.get_type(), ctx.ctx.i32_type()); + + self.name_field().store(ctx, self.value, name, self.name); + } + + fn file_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().file + } + + /// Stores the file name of the exception source into this instance. + pub fn store_file(&self, ctx: &CodeGenContext<'ctx, '_>, file: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(file, self.llvm_usize).is_ok()); + + self.file_field().store(ctx, self.value, file, self.name); + } + + fn line_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().line + } + + fn col_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().col + } + + /// Stores the [location][Location] of the exception source into this instance. + pub fn store_location( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + location: Location, + ) { + let llvm_i32 = ctx.ctx.i32_type(); + + let filename = ctx.gen_string(generator, location.file.0); + self.store_file(ctx, filename); + + self.line_field().store( + ctx, + self.value, + llvm_i32.const_int(location.row as u64, false), + self.name, + ); + self.col_field().store( + ctx, + self.value, + llvm_i32.const_int(location.column as u64, false), + self.name, + ); + } + + fn func_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().func + } + + /// Stores the function name of the exception source into this instance. + pub fn store_func(&self, ctx: &CodeGenContext<'ctx, '_>, func: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(func, self.llvm_usize).is_ok()); + + self.func_field().store(ctx, self.value, func, self.name); + } + + fn message_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().message + } + + /// Stores the exception message into this instance. + pub fn store_message(&self, ctx: &CodeGenContext<'ctx, '_>, message: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(message, self.llvm_usize).is_ok()); + + self.message_field().store(ctx, self.value, message, self.name); + } + + fn param0_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param0 + } + + fn param1_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param1 + } + + fn param2_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param2 + } + + /// Stores the parameters of the exception into this instance. + /// + /// If the parameter does not exist, pass `i64 0` in the parameter slot. + pub fn store_params(&self, ctx: &CodeGenContext<'ctx, '_>, params: &[IntValue<'ctx>; 3]) { + debug_assert!(params.iter().all(|p| p.get_type() == ctx.ctx.i64_type())); + + [self.param0_field(), self.param1_field(), self.param2_field()] + .into_iter() + .zip_eq(params) + .for_each(|(field, param)| { + field.store(ctx, self.value, *param, self.name); + }); + } +} + +impl<'ctx> ProxyValue<'ctx> for ExceptionValue<'ctx> { + type ABI = PointerValue<'ctx>; + type Base = PointerValue<'ctx>; + type Type = ExceptionType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> StructProxyValue<'ctx> for ExceptionValue<'ctx> {} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ExceptionValue<'ctx>) -> Self { + value.as_base_value() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 7a43ba41..50933333 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -2,6 +2,7 @@ use inkwell::{types::IntType, values::BasicValue}; use super::{types::ProxyType, CodeGenContext}; pub use array::*; +pub use exception::*; pub use list::*; pub use option::*; pub use range::*; @@ -9,6 +10,7 @@ pub use string::*; pub use tuple::*; mod array; +mod exception; mod list; pub mod ndarray; mod option; From 715dc71396675d6fd9951b911a7677e482a78a38 Mon Sep 17 00:00:00 2001 From: occheung Date: Mon, 10 Feb 2025 11:08:24 +0800 Subject: [PATCH 12/13] nac3artiq: acquire special python identifiers --- nac3artiq/demo/min_artiq.py | 11 +- nac3artiq/src/codegen.rs | 240 ++++++++++++++++++++---------------- nac3artiq/src/lib.rs | 31 +++++ 3 files changed, 172 insertions(+), 110 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index fef018b2..cba3ad24 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -16,7 +16,7 @@ __all__ = [ "rpc", "ms", "us", "ns", "print_int32", "print_int64", "Core", "TTLOut", - "parallel", "sequential" + "parallel", "legacy_parallel", "sequential" ] @@ -245,7 +245,7 @@ class Core: embedding = EmbeddingMap() if allow_registration: - compiler.analyze(registered_functions, registered_classes, set()) + compiler.analyze(registered_functions, registered_classes, special_ids, set()) allow_registration = False if hasattr(method, "__self__"): @@ -336,4 +336,11 @@ class UnwrapNoneError(Exception): artiq_builtin = True parallel = KernelContextManager() +legacy_parallel = KernelContextManager() sequential = KernelContextManager() + +special_ids = { + "parallel": id(parallel), + "legacy_parallel": id(legacy_parallel), + "sequential": id(sequential), +} diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index d086420c..cc625a02 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -12,7 +12,7 @@ use pyo3::{ PyObject, PyResult, Python, }; -use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use super::{symbol_resolver::InnerResolver, timeline::TimeFns, SpecialPythonId}; use nac3core::{ codegen::{ expr::{create_fn_and_call, destructure_range, gen_call, infer_and_call_function}, @@ -86,6 +86,9 @@ pub struct ArtiqCodeGenerator<'a> { /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// statement, which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, + + /// Specially treated python IDs to identify `with parallel` and `with sequential` blocks. + special_ids: SpecialPythonId, } impl<'a> ArtiqCodeGenerator<'a> { @@ -93,6 +96,7 @@ impl<'a> ArtiqCodeGenerator<'a> { name: String, size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { @@ -103,6 +107,7 @@ impl<'a> ArtiqCodeGenerator<'a> { end: None, timeline, parallel_mode: ParallelMode::None, + special_ids, } } @@ -112,9 +117,10 @@ impl<'a> ArtiqCodeGenerator<'a> { ctx: &Context, target_machine: &TargetMachine, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); - Self::new(name, llvm_usize, timeline) + Self::new(name, llvm_usize, timeline, special_ids) } /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the @@ -260,122 +266,140 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> { // - If there is a end variable, it indicates that we are (indirectly) inside a // parallel block, and we should update the max end value. if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { - if id == &"parallel".into() || id == &"legacy_parallel".into() { - let old_start = self.start.take(); - let old_end = self.end.take(); - let old_parallel_mode = self.parallel_mode; + let resolver = ctx.resolver.clone(); + if let Some(static_value) = + if let Some((_ptr, static_value, _counter)) = ctx.var_assignment.get(id) { + static_value.clone() + } else if let Some(ValueEnum::Static(val)) = + resolver.get_symbol_value(*id, ctx, self) + { + Some(val) + } else { + None + } + { + let python_id = static_value.get_unique_identifier(); + if python_id == self.special_ids.parallel + || python_id == self.special_ids.legacy_parallel + { + let old_start = self.start.take(); + let old_end = self.end.take(); + let old_parallel_mode = self.parallel_mode; - let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + let now = if let Some(old_start) = &old_start { + self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + ctx, + self, + old_start.custom.unwrap(), + )? + } else { + self.timeline.emit_now_mu(ctx) + }; + + // Emulate variable allocation, as we need to use the CodeGenContext + // HashMap to store our variable due to lifetime limitation + // Note: we should be able to store variables directly if generic + // associative type is used by limiting the lifetime of CodeGenerator to + // the LLVM Context. + // The name is guaranteed to be unique as users cannot use this as variable + // name. + self.start = old_start.clone().map_or_else( + || { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: start, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let start = self + .gen_store_target(ctx, &start_expr, Some("start.addr"))? + .unwrap(); + ctx.builder.build_store(start, now).unwrap(); + Ok(Some(start_expr)) as Result<_, String> + }, + |v| Ok(Some(v)), + )?; + let end = format!("with-{}-end", self.name_counter).into(); + let end_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: end, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); + ctx.builder.build_store(end, now).unwrap(); + self.end = Some(end_expr); + self.name_counter += 1; + self.parallel_mode = if python_id == self.special_ids.parallel { + ParallelMode::Deep + } else if python_id == self.special_ids.legacy_parallel { + ParallelMode::Legacy + } else { + unreachable!() + }; + + self.gen_block(ctx, body.iter())?; + + let current = ctx.builder.get_insert_block().unwrap(); + + // if the current block is terminated, move before the terminator + // we want to set the timeline before reaching the terminator + // TODO: This may be unsound if there are multiple exit paths in the + // block... e.g. + // if ...: + // return + // Perhaps we can fix this by using actual with block? + let reset_position = if let Some(terminator) = current.get_terminator() { + ctx.builder.position_before(&terminator); + true + } else { + false + }; + + // set duration + let end_expr = self.end.take().unwrap(); + let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( ctx, self, - old_start.custom.unwrap(), - )? - } else { - self.timeline.emit_now_mu(ctx) - }; + end_expr.custom.unwrap(), + )?; - // Emulate variable allocation, as we need to use the CodeGenContext - // HashMap to store our variable due to lifetime limitation - // Note: we should be able to store variables directly if generic - // associative type is used by limiting the lifetime of CodeGenerator to - // the LLVM Context. - // The name is guaranteed to be unique as users cannot use this as variable - // name. - self.start = old_start.clone().map_or_else( - || { - let start = format!("with-{}-start", self.name_counter).into(); - let start_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: start, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let start = self - .gen_store_target(ctx, &start_expr, Some("start.addr"))? - .unwrap(); - ctx.builder.build_store(start, now).unwrap(); - Ok(Some(start_expr)) as Result<_, String> - }, - |v| Ok(Some(v)), - )?; - let end = format!("with-{}-end", self.name_counter).into(); - let end_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: end, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); - ctx.builder.build_store(end, now).unwrap(); - self.end = Some(end_expr); - self.name_counter += 1; - self.parallel_mode = match id.to_string().as_str() { - "parallel" => ParallelMode::Deep, - "legacy_parallel" => ParallelMode::Legacy, - _ => unreachable!(), - }; + // inside a sequential block + if old_start.is_none() { + self.timeline.emit_at_mu(ctx, end_val); + } - self.gen_block(ctx, body.iter())?; + // inside a parallel block, should update the outer max now_mu + self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - let current = ctx.builder.get_insert_block().unwrap(); + self.parallel_mode = old_parallel_mode; + self.end = old_end; + self.start = old_start; - // if the current block is terminated, move before the terminator - // we want to set the timeline before reaching the terminator - // TODO: This may be unsound if there are multiple exit paths in the - // block... e.g. - // if ...: - // return - // Perhaps we can fix this by using actual with block? - let reset_position = if let Some(terminator) = current.get_terminator() { - ctx.builder.position_before(&terminator); - true - } else { - false - }; + if reset_position { + ctx.builder.position_at_end(current); + } - // set duration - let end_expr = self.end.take().unwrap(); - let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( - ctx, - self, - end_expr.custom.unwrap(), - )?; + return Ok(()); + } else if python_id == self.special_ids.sequential { + // For deep parallel, temporarily take away start to avoid function calls in + // the block from resetting the timeline. + // This does not affect legacy parallel, as the timeline will be reset after + // this block finishes execution. + let start = self.start.take(); + self.gen_block(ctx, body.iter())?; + self.start = start; - // inside a sequential block - if old_start.is_none() { - self.timeline.emit_at_mu(ctx, end_val); + // Reset the timeline when we are exiting the sequential block + // Legacy parallel does not need this, since it will be reset after codegen + // for this statement is completed + if self.parallel_mode == ParallelMode::Deep { + self.timeline_reset_start(ctx)?; + } + + return Ok(()); } - - // inside a parallel block, should update the outer max now_mu - self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - - self.parallel_mode = old_parallel_mode; - self.end = old_end; - self.start = old_start; - - if reset_position { - ctx.builder.position_at_end(current); - } - - return Ok(()); - } else if id == &"sequential".into() { - // For deep parallel, temporarily take away start to avoid function calls in - // the block from resetting the timeline. - // This does not affect legacy parallel, as the timeline will be reset after - // this block finishes execution. - let start = self.start.take(); - self.gen_block(ctx, body.iter())?; - self.start = start; - - // Reset the timeline when we are exiting the sequential block - // Legacy parallel does not need this, since it will be reset after codegen - // for this statement is completed - if self.parallel_mode == ParallelMode::Deep { - self.timeline_reset_start(ctx)?; - } - - return Ok(()); } } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ba6c4fae..d4136a07 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -162,6 +162,13 @@ pub struct PrimitivePythonId { module: u64, } +#[derive(Clone, Default)] +pub struct SpecialPythonId { + parallel: u64, + legacy_parallel: u64, + sequential: u64, +} + type TopLevelComponent = (Stmt, String, PyObject); // TopLevelComposer is unsendable as it holds the unification table, which is @@ -179,6 +186,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, + special_ids: SpecialPythonId, /// LLVM-related options for code generation. llvm_options: CodeGenLLVMOptions, } @@ -797,6 +805,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), )) }) .collect(); @@ -813,6 +822,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); @@ -1192,6 +1202,7 @@ impl Nac3 { string_store: Arc::new(string_store.into()), exception_ids: Arc::default(), deferred_eval_store: DeferredEvaluationStore::new(), + special_ids: Default::default(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, target: isa.get_llvm_target_options(), @@ -1203,6 +1214,7 @@ impl Nac3 { &mut self, functions: &PySet, classes: &PySet, + special_ids: &PyDict, content_modules: &PySet, ) -> PyResult<()> { let (modules, class_ids) = @@ -1236,6 +1248,25 @@ impl Nac3 { for module in modules.into_values() { self.register_module(&module, &class_ids)?; } + + self.special_ids = SpecialPythonId { + parallel: special_ids.get_item("parallel").ok().flatten().unwrap().extract().unwrap(), + legacy_parallel: special_ids + .get_item("legacy_parallel") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + sequential: special_ids + .get_item("sequential") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + }; + Ok(()) } From 82a580c5c6e68e0c8f92ecd44da2a18b1adf7d0a Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 10 Feb 2025 16:53:35 +0800 Subject: [PATCH 13/13] flake: update ARTIQ source used for PGO --- flake.nix | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index 51551c77..243153b2 100644 --- a/flake.nix +++ b/flake.nix @@ -113,8 +113,8 @@ (pkgs.fetchFromGitHub { owner = "m-labs"; repo = "artiq"; - rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6"; - sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak="; + rev = "554b0749ca5985bf4d006c4f29a05e83de0a226d"; + sha256 = "sha256-3eSNHTSlmdzLMcEMIspxqjmjrcQe4aIGqIfRgquUg18="; }) ]; buildInputs = [