From 843ad891642f52c2f2b3dbc847cd1c340572be1b Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 10:10:23 +0800 Subject: [PATCH] [core] codegen: Add Proxy{Type,Value}::as_abi_{type,value} Needed for PtrToOrBasic{Type,Value}. --- nac3artiq/src/codegen.rs | 2 +- nac3artiq/src/symbol_resolver.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 146 +++++++++--------- nac3core/src/codegen/expr.rs | 22 +-- nac3core/src/codegen/irrt/ndarray/array.rs | 4 +- nac3core/src/codegen/irrt/ndarray/basic.rs | 33 ++-- .../src/codegen/irrt/ndarray/broadcast.rs | 2 +- nac3core/src/codegen/irrt/ndarray/indexing.rs | 4 +- nac3core/src/codegen/irrt/ndarray/iter.rs | 8 +- .../src/codegen/irrt/ndarray/transpose.rs | 4 +- nac3core/src/codegen/mod.rs | 8 +- nac3core/src/codegen/numpy.rs | 16 +- nac3core/src/codegen/test.rs | 6 +- nac3core/src/codegen/types/list.rs | 7 +- nac3core/src/codegen/types/mod.rs | 12 +- nac3core/src/codegen/types/ndarray/array.rs | 10 +- .../src/codegen/types/ndarray/broadcast.rs | 7 +- .../src/codegen/types/ndarray/contiguous.rs | 7 +- .../src/codegen/types/ndarray/indexing.rs | 7 +- nac3core/src/codegen/types/ndarray/mod.rs | 7 +- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- nac3core/src/codegen/types/range.rs | 9 +- nac3core/src/codegen/types/tuple.rs | 5 + nac3core/src/codegen/types/utils/slice.rs | 7 +- nac3core/src/codegen/values/list.rs | 7 +- nac3core/src/codegen/values/mod.rs | 14 +- .../src/codegen/values/ndarray/broadcast.rs | 5 + .../src/codegen/values/ndarray/contiguous.rs | 13 +- .../src/codegen/values/ndarray/indexing.rs | 5 + nac3core/src/codegen/values/ndarray/mod.rs | 21 ++- nac3core/src/codegen/values/ndarray/nditer.rs | 9 +- nac3core/src/codegen/values/range.rs | 11 +- nac3core/src/codegen/values/tuple.rs | 5 + nac3core/src/codegen/values/utils/slice.rs | 5 + nac3core/src/toplevel/builtins.rs | 8 +- 35 files changed, 279 insertions(+), 170 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 4c86028f..2cc54387 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -761,7 +761,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - ndarray.as_base_value().into() + ndarray.as_abi_value(ctx).into() } _ => { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 4b398a9b..06a9400c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1146,7 +1146,7 @@ impl InnerResolver { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1316,7 +1316,7 @@ impl InnerResolver { }; let ndarray = llvm_ndarray - .as_base_type() + .as_abi_type() .get_element_type() .into_struct_type() .const_named_struct(&[ @@ -1328,7 +1328,7 @@ impl InnerResolver { ]); let ndarray_global = ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 6cacac45..4974a605 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, IntValue}, + values::{BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -137,7 +137,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -197,7 +197,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -273,7 +273,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -338,7 +338,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -402,7 +402,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -448,7 +448,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -485,7 +485,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -550,7 +550,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -600,7 +600,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -650,7 +650,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -767,7 +767,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1026,7 +1026,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1072,7 +1072,7 @@ where Ok(result) })?; - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1419,7 +1419,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_copysign` builtin function. @@ -1453,7 +1453,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_fmax` builtin function. @@ -1487,7 +1487,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_fmin` builtin function. @@ -1521,7 +1521,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_ldexp` builtin function. @@ -1557,7 +1557,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_hypot` builtin function. @@ -1591,7 +1591,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_nextafter` builtin function. @@ -1625,7 +1625,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - Ok(result.to_basic_value_enum()) + Ok(result.to_basic_value_enum(ctx)) } /// Invokes the `np_linalg_cholesky` linalg function @@ -1653,11 +1653,11 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_cholesky( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_qr` linalg function @@ -1699,20 +1699,20 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_qr( ctx, - x1_c.as_base_value().into(), - q_c.as_base_value().into(), - r_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), + r_c.as_abi_value(ctx).into(), None, ); - let q = q.as_base_value().as_basic_value_enum(); - let r = r.as_base_value().as_basic_value_enum(); + let q = q.as_abi_value(ctx); + let r = r.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( ctx, - [q, r], + [q.into(), r.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_svd` linalg function @@ -1760,19 +1760,19 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_svd( ctx, - x1_c.as_base_value().into(), - u_c.as_base_value().into(), - s_c.as_base_value().into(), - vh_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), + s_c.as_abi_value(ctx).into(), + vh_c.as_abi_value(ctx).into(), None, ); - let u = u.as_base_value().as_basic_value_enum(); - let s = s.as_base_value().as_basic_value_enum(); - let vh = vh.as_base_value().as_basic_value_enum(); + let u = u.as_abi_value(ctx); + let s = s.as_abi_value(ctx); + let vh = vh.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) - .construct_from_objects(ctx, [u, s, vh], None); - Ok(tuple.as_base_value().into()) + .construct_from_objects(ctx, [u.into(), s.into(), vh.into()], None); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_inv` linalg function @@ -1800,12 +1800,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_inv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_pinv` linalg function @@ -1845,12 +1845,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_pinv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_lu` linalg function @@ -1892,20 +1892,20 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let u_c = u.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_lu( ctx, - x1_c.as_base_value().into(), - l_c.as_base_value().into(), - u_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + l_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), None, ); - let l = l.as_base_value().as_basic_value_enum(); - let u = u.as_base_value().as_basic_value_enum(); + let l = l.as_abi_value(ctx); + let u = u.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( ctx, - [l, u], + [l.into(), u.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -1953,13 +1953,13 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_matrix_power( ctx, - x1_c.as_base_value().into(), - x2_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + x2_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_det` linalg function @@ -1993,8 +1993,8 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( let out_c = det.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_det( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); @@ -2035,20 +2035,20 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let z_c = z.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_schur( ctx, - x1_c.as_base_value().into(), - t_c.as_base_value().into(), - z_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + t_c.as_abi_value(ctx).into(), + z_c.as_abi_value(ctx).into(), None, ); - let t = t.as_base_value().as_basic_value_enum(); - let z = z.as_base_value().as_basic_value_enum(); + let t = t.as_abi_value(ctx); + let z = z.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( ctx, - [t, z], + [t.into(), z.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2083,18 +2083,18 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let q_c = q.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_hessenberg( ctx, - x1_c.as_base_value().into(), - h_c.as_base_value().into(), - q_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + h_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), None, ); - let h = h.as_base_value().as_basic_value_enum(); - let q = q.as_base_value().as_basic_value_enum(); + let h = h.as_abi_value(ctx); + let q = q.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( ctx, - [h, q], + [h.into(), q.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4da0ef33..11b31350 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1307,7 +1307,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( emit_cont_bb(ctx, list); - Ok(Some(list.as_base_value().into())) + Ok(Some(list.as_abi_value(ctx).into())) } /// Generates LLVM IR for a binary operator expression using the [`Type`] and @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } _ => todo!("Operator not supported"), @@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let result = left .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) .split_unsized(generator, ctx); - Ok(Some(result.to_basic_value_enum().into())) + Ok(Some(result.to_basic_value_enum(ctx).into())) } else { // For other operations, they are all elementwise operations. @@ -1601,7 +1601,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(result) }) .unwrap(); - Ok(Some(result.as_base_value().into())) + Ok(Some(result.as_abi_value(ctx).into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1796,7 +1796,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - mapped_ndarray.as_base_value().into() + mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() })) @@ -1883,7 +1883,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - return Ok(Some(result_ndarray.as_base_value().into())); + return Ok(Some(result_ndarray.as_abi_value(ctx).into())); } } @@ -2493,7 +2493,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ); ctx.builder.build_store(elem_ptr, *v).unwrap(); } - arr_str_ptr.as_base_value().into() + arr_str_ptr.as_abi_value(ctx).into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -2988,7 +2988,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.as_base_value().into() + res_array_ret.as_abi_value(ctx).into() } else { let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { @@ -3049,8 +3049,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let result = ndarray .index(generator, ctx, &indices) .split_unsized(generator, ctx) - .to_basic_value_enum(); - return Ok(Some(ValueEnum::Dynamic(result))); + .to_basic_value_enum(ctx); + return Ok(Some(result.into())); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs index 5e9c0f0b..63a2ab00 100644 --- a/nac3core/src/codegen/irrt/ndarray/array.rs +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -36,7 +36,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato ctx, &name, None, - &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + &[list.as_abi_value(ctx).into(), ndims.into(), shape.base_ptr(ctx, generator).into()], None, None, ); @@ -65,7 +65,7 @@ pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( ctx, &name, None, - &[list.as_base_value().into(), ndarray.as_base_value().into()], + &[list.as_abi_value(ctx).into(), ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index aa792b15..5f291c86 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -93,7 +93,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); @@ -101,7 +101,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("size"), None, ) @@ -118,7 +118,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); @@ -126,7 +126,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("nbytes"), None, ) @@ -143,7 +143,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); @@ -151,7 +151,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("len"), None, ) @@ -167,7 +167,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); @@ -175,7 +175,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("is_c_contiguous"), None, ) @@ -194,7 +194,7 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!(index.get_type(), llvm_usize); @@ -204,7 +204,10 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx, &name, Some(llvm_pi8.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())], + &[ + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), + (llvm_usize.into(), index.into()), + ], Some("pelement"), None, ) @@ -227,7 +230,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!( BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), @@ -241,7 +244,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized &name, Some(llvm_pi8.into()), &[ - (llvm_ndarray.into(), ndarray.as_base_value().into()), + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), @@ -258,7 +261,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); @@ -266,7 +269,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx, &name, None, - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], None, None, ); @@ -288,7 +291,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index a7d40a57..59b0e4cd 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -30,7 +30,7 @@ pub fn call_nac3_ndarray_broadcast_to<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index df5b27de..0d5d920e 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -25,8 +25,8 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( &[ indices.size(ctx, generator).into(), indices.base_ptr(ctx, generator).into(), - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index ad90178c..e4424df0 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -40,8 +40,8 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - (iter.get_type().as_base_type().into(), iter.as_base_value().into()), - (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()), + (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), + (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], None, @@ -63,7 +63,7 @@ pub fn call_nac3_nditer_has_element<'ctx>( ctx, &name, Some(ctx.ctx.bool_type().into()), - &[iter.as_base_value().into()], + &[iter.as_abi_value(ctx).into()], None, None, ) @@ -77,5 +77,5 @@ pub fn call_nac3_nditer_has_element<'ctx>( pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); - infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); + infer_and_call_function(ctx, &name, None, &[iter.as_abi_value(ctx).into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs index 6d152dd1..331611fa 100644 --- a/nac3core/src/codegen/irrt/ndarray/transpose.rs +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -34,8 +34,8 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { axes.base_ptr(ctx, generator) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 73a28b7a..a188d1c3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -562,7 +562,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_abi_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -572,7 +572,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_abi_type().into() } _ => unreachable!( @@ -626,7 +626,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_abi_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), @@ -800,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new_with_generator(generator, context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3cdd1ef3..2eec88da 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -44,7 +44,7 @@ pub fn gen_ndarray_empty<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.zeros`. @@ -69,7 +69,7 @@ pub fn gen_ndarray_zeros<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.ones`. @@ -94,7 +94,7 @@ pub fn gen_ndarray_ones<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.full`. @@ -127,7 +127,7 @@ pub fn gen_ndarray_full<'ctx>( fill_value_arg, None, ); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } pub fn gen_ndarray_array<'ctx>( @@ -166,7 +166,7 @@ pub fn gen_ndarray_array<'ctx>( .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) .atleast_nd(generator, context, ndims); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.eye`. @@ -225,7 +225,7 @@ pub fn gen_ndarray_eye<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.identity`. @@ -253,7 +253,7 @@ pub fn gen_ndarray_identity<'ctx>( .unwrap(); let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.copy`. @@ -274,7 +274,7 @@ pub fn gen_ndarray_copy<'ctx>( let this = NDArrayType::from_unifier_type(generator, context, this_ty) .map_value(this_arg.into_pointer_value(), None); let ndarray = this.make_copy(generator, context); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.fill`. diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ecc0ba96..15c4654a 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -447,7 +447,7 @@ fn test_classes_list_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); - assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); + assert!(ListType::is_representable(llvm_list.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -458,7 +458,7 @@ fn test_classes_range_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_range = RangeType::new_with_generator(&generator, &ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type(), llvm_usize).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -470,5 +470,5 @@ fn test_classes_ndarray_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); - assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); + assert!(NDArrayType::is_representable(llvm_ndarray.as_abi_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 60015b8c..f99ad5cb 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -305,6 +305,7 @@ impl<'ctx> ListType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; @@ -344,12 +345,16 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 1c66a7d1..7a5e35e6 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -38,8 +38,10 @@ pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a - /// [LLVM pointer type][PointerType] for any non-primitive types. + /// The ABI type of which values of this type possess. + type ABI: BasicType<'ctx>; + + /// The LLVM type of which values of this type possess. type Base: BasicType<'ctx>; /// The type of values represented by this type. @@ -115,4 +117,10 @@ pub trait ProxyType<'ctx>: Into { /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; + + /// Returns this proxy as its ABI type, i.e. the expected type representation if a value of this + /// [`ProxyType`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_type(&self) -> Self::ABI; } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 70611127..9630ec15 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -151,7 +151,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, |generator, ctx| { let ndarray = self.construct_numpy_array_from_list_copy_none_impl( @@ -160,7 +160,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() @@ -189,11 +189,11 @@ impl<'ctx> NDArrayType<'ctx> { |_generator, _ctx| Ok(copy), |generator, ctx| { let ndarray = ndarray.make_copy(generator, ctx); // Force copy - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, - |_generator, _ctx| { + |_generator, ctx| { // No need to copy. Return `ndarray` itself. - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index af1a26fa..40847ce2 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -127,6 +127,7 @@ impl<'ctx> ShapeEntryType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ShapeEntryValue<'ctx>; @@ -160,12 +161,16 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 1987ab6d..40311a57 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -189,6 +189,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; @@ -230,12 +231,16 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 8e15c903..ec214ceb 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -158,6 +158,7 @@ impl<'ctx> NDIndexType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; @@ -188,12 +189,16 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 1743fe2b..a79a1f30 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -427,6 +427,7 @@ impl<'ctx> NDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; @@ -458,12 +459,16 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 6246eef2..ba21a7ea 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -185,6 +185,7 @@ impl<'ctx> NDIterType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; @@ -216,12 +217,16 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index 158152bf..b6f15c70 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -72,7 +72,7 @@ impl<'ctx> RangeType<'ctx> { /// Returns the type of all fields of this `range` type. #[must_use] pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + self.as_abi_type().get_element_type().into_array_type().get_element_type().into_int_type() } /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. @@ -120,6 +120,7 @@ impl<'ctx> RangeType<'ctx> { } impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = RangeValue<'ctx>; @@ -163,12 +164,16 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index d05b7f26..29e93233 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -157,6 +157,7 @@ impl<'ctx> TupleType<'ctx> { } impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type ABI = StructType<'ctx>; type Base = StructType<'ctx>; type Value = TupleValue<'ctx>; @@ -182,6 +183,10 @@ impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for StructType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index b7fafefa..e482ed5b 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -174,6 +174,7 @@ impl<'ctx> SliceType<'ctx> { } impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = SliceValue<'ctx>; @@ -229,12 +230,16 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 075f7f64..8b2b6cb2 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -110,7 +110,7 @@ impl<'ctx> ListValue<'ctx> { let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( - ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_abi_type(), "").unwrap(), self.llvm_usize, self.name, ) @@ -118,6 +118,7 @@ impl<'ctx> ListValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ListType<'ctx>; @@ -128,6 +129,10 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 9a246356..90f327e0 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,6 +1,6 @@ use inkwell::{types::IntType, values::BasicValue}; -use super::types::ProxyType; +use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; @@ -16,8 +16,10 @@ pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the - /// [LLVM pointer type][PointerValue]. + /// The ABI type of LLVM values represented by this instance. + type ABI: BasicValue<'ctx>; + + /// The type of LLVM values represented by this instance. type Base: BasicValue<'ctx>; /// The type of this value. @@ -33,4 +35,10 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; + + /// Returns this proxy as its ABI value, i.e. the expected value representation if a value + /// represented by this [`ProxyValue`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> Self::ABI; } diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index acbd2997..883b4613 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -58,6 +58,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ShapeEntryType<'ctx>; @@ -68,6 +69,10 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 65e80258..a23be229 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_base_value(), value, self.name); + self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -49,7 +49,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_base_value(), value, self.name); + self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -61,7 +61,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.as_base_value(), value, self.name); + self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -70,6 +70,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ContiguousNDArrayType<'ctx>; @@ -84,6 +85,10 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -124,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); + let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3b7b8f10..00846713 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -68,6 +68,7 @@ impl<'ctx> NDIndexValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIndexType<'ctx>; @@ -78,6 +79,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index cba35ad2..1c105d64 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); + self.shape_field(ctx).set(ctx, self.as_abi_value(ctx), dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); + self.strides_field(ctx).set(ctx, self.as_abi_value(ctx), strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name); + self.data_field(ctx).set(ctx, self.as_abi_value(ctx), data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -462,6 +462,7 @@ impl<'ctx> NDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDArrayType<'ctx>; @@ -477,6 +478,10 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -503,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.shape_field(ctx).get(ctx, self.0.as_abi_value(ctx), self.0.name) } fn size( @@ -601,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.strides_field(ctx).get(ctx, self.0.as_abi_value(ctx), self.0.name) } fn size( @@ -699,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.data_field(ctx).get(ctx, self.0.as_abi_value(ctx), self.0.name) } fn size( @@ -963,10 +968,10 @@ impl<'ctx> ScalarOrNDArray<'ctx> { /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] - pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { + pub fn to_basic_value_enum(self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { match self { ScalarOrNDArray::Scalar(scalar) => scalar, - ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), + ScalarOrNDArray::NDArray(ndarray) => ndarray.as_abi_value(ctx).into(), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 5479b929..3fdd0a8c 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).get(ctx, self.as_base_value(), self.name); + let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).get(ctx, self.as_base_value(), self.name) + self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. @@ -105,6 +105,7 @@ impl<'ctx> NDIterValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIterType<'ctx>; @@ -115,6 +116,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index b1a5806a..20bdba79 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -34,7 +34,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], var_name.as_str(), ) @@ -49,7 +49,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], var_name.as_str(), ) @@ -64,7 +64,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], var_name.as_str(), ) @@ -137,6 +137,7 @@ impl<'ctx> RangeValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = RangeType<'ctx>; @@ -147,6 +148,10 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 4558f18c..08b2b8be 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -57,6 +57,7 @@ impl<'ctx> TupleValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type ABI = StructValue<'ctx>; type Base = StructValue<'ctx>; type Type = TupleType<'ctx>; @@ -67,6 +68,10 @@ impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for StructValue<'ctx> { diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index df9e4de5..21453f4d 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -150,6 +150,7 @@ impl<'ctx> SliceValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = SliceType<'ctx>; @@ -160,6 +161,10 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1c3b0854..165f64a8 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -664,7 +664,7 @@ impl<'a> BuiltinBuilder<'a> { zelf.store_end(ctx, stop); zelf.store_step(ctx, step); - Ok(Some(zelf.as_base_value().into())) + Ok(Some(zelf.as_abi_value(ctx).into())) }, )))), loc: None, @@ -1320,7 +1320,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(result_tuple.as_base_value().into())) + Ok(Some(result_tuple.as_abi_value(ctx).into())) }), ) } @@ -1356,7 +1356,7 @@ impl<'a> BuiltinBuilder<'a> { .map_value(arg_val.into_pointer_value(), None); let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument - Ok(Some(ndarray.as_base_value().into())) + Ok(Some(ndarray.as_abi_value(ctx).into())) }), ), @@ -1410,7 +1410,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + Ok(Some(new_ndarray.as_abi_value(ctx).as_basic_value_enum())) }), ) }