From fd78f7a0e8a10729725d0c72478ce096605f5edd Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 14 Aug 2024 15:56:59 +0800 Subject: [PATCH] WIP: core/ndstrides: done --- nac3core/src/codegen/builtin_fns.rs | 513 +----------------- nac3core/src/codegen/expr.rs | 15 +- nac3core/src/codegen/irrt/mod.rs | 4 +- nac3core/src/codegen/irrt/util.rs | 107 ---- nac3core/src/codegen/model/core.rs | 1 + nac3core/src/codegen/model/float.rs | 88 +++ nac3core/src/codegen/model/function.rs | 125 +++++ nac3core/src/codegen/model/int.rs | 1 - nac3core/src/codegen/model/mod.rs | 3 + nac3core/src/codegen/model/ptr.rs | 12 +- nac3core/src/codegen/numpy_new.rs | 21 +- nac3core/src/codegen/object/list.rs | 2 - nac3core/src/codegen/object/ndarray/array.rs | 17 +- .../src/codegen/object/ndarray/broadcast.rs | 2 +- .../src/codegen/object/ndarray/functions.rs | 73 ++- .../src/codegen/object/ndarray/indexing.rs | 5 +- .../src/codegen/object/ndarray/mapping.rs | 4 +- nac3core/src/codegen/object/ndarray/mod.rs | 88 +-- .../src/codegen/object/ndarray/nalgebra.rs | 52 ++ .../src/codegen/object/ndarray/product.rs | 2 +- nac3core/src/codegen/object/ndarray/scalar.rs | 13 +- .../src/codegen/object/ndarray/shape_util.rs | 7 +- nac3core/src/codegen/object/tuple.rs | 27 +- nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/toplevel/builtins.rs | 136 ++++- nac3core/src/toplevel/numpy.rs | 3 +- 26 files changed, 527 insertions(+), 796 deletions(-) delete mode 100644 nac3core/src/codegen/irrt/util.rs create mode 100644 nac3core/src/codegen/model/float.rs create mode 100644 nac3core/src/codegen/model/function.rs diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index d21ebff1..44f104ff 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,17 +1,12 @@ -use inkwell::types::BasicTypeEnum; -use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue}; +use inkwell::values::{BasicValueEnum, IntValue}; use inkwell::IntPredicate; use itertools::Itertools; -use crate::codegen::classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, -}; +use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, RangeValue, TypedArrayLikeAccessor}; use crate::codegen::expr::destructure_range; use crate::codegen::irrt::calculate_len_for_slice_range; -use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; +use crate::codegen::{CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Type, TypeEnum}; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. @@ -84,505 +79,3 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( } }) } - -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: Vec>, -) -> PointerValue<'ctx> { - let field_ty = - out_matrices.iter().map(BasicValueEnum::get_type).collect::>(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.into_iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, v).unwrap(); - } - } - out_ptr -} - -/// Invokes the `np_linalg_cholesky` linalg function -pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_cholesky"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `np_linalg_qr` linalg function -pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_qr"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unimplemented!("{FN_NAME} operates on float type NdArrays only"); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); - - let out_ptr = build_output_struct(ctx, vec![out_q, out_r]); - - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `np_linalg_svd` linalg function -pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_svd"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); - - let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `np_linalg_inv` linalg function -pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_inv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_inv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `np_linalg_pinv` linalg function -pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_pinv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_pinv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `sp_linalg_lu` linalg function -pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "sp_linalg_lu"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); - - let out_ptr = build_output_struct(ctx, vec![out_l, out_u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `np_linalg_matrix_power` linalg function -pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - todo!(); - - /* - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); - - let llvm_usize = generator.get_size_type(ctx.ctx); - if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - // Changing second parameter to a `NDArray` for uniformity in function call - let n2_array = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - unsafe { - n2_array.data().set_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - n2.as_basic_value_enum(), - ); - }; - let n2_array = n2_array.as_base_value().as_basic_value_enum(); - - let outdim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let outdim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) - } - */ -} - -/// Invokes the `np_linalg_det` linalg function -pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; - - let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(_) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - // Changing second parameter to a `NDArray` for uniformity in function call - let out = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); - let res = - unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - Ok(res) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `sp_linalg_schur` linalg function -pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "sp_linalg_schur"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); - - let out_ptr = build_output_struct(ctx, vec![out_t, out_z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} - -/// Invokes the `sp_linalg_hessenberg` linalg function -pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "sp_linalg_hessenberg"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); - - let out_ptr = build_output_struct(ctx, vec![out_h, out_q]); - Ok(ctx - .builder - .build_load(out_ptr, "Hessenberg_decomposition_result") - .map(Into::into) - .unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } -} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 64f39ac7..50befb48 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,10 +32,7 @@ use crate::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{AnyType, BasicType, BasicTypeEnum}, - values::{ - BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, - StructValue, - }, + values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::{chain, izip, Either, Itertools}; @@ -314,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.raise_exn( generator, "0:NotImplementedError", - msg.into(), + msg, [None, None, None], self.current_loc, ); @@ -639,7 +636,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap())); let err_msg = self.gen_string(generator, err_msg); - self.make_assert_impl(generator, cond, err_name, err_msg.into(), params, loc); + self.make_assert_impl(generator, cond, err_name, err_msg, params, loc); } pub fn make_assert_impl( @@ -1574,9 +1571,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( gen_binop_expr_with_values( generator, ctx, - (&Some(left.dtype), left.instance), + (&Some(left.dtype), left.value), op, - (&Some(right.dtype), right.instance), + (&Some(right.dtype), right.value), ctx.current_loc, )? .unwrap() @@ -2689,7 +2686,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.raise_exn( generator, "0:UnwrapNoneError", - err_msg.into(), + err_msg, [None, None, None], ctx.current_loc, ); diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 9e8434f1..2dd71ef1 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -2,7 +2,6 @@ use crate::symbol_resolver::SymbolResolver; use crate::typecheck::typedef::Type; mod test; -pub mod util; use super::model::*; use super::object::ndarray::broadcast::ShapeEntry; @@ -17,6 +16,7 @@ use super::{ }; use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::stmt::gen_for_callback_incrementing; +use function::{get_sizet_dependent_function_name, CallFunction}; use inkwell::values::BasicValue; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -29,8 +29,6 @@ use inkwell::{ }; use itertools::Either; use nac3parser::ast::Expr; -use util::function::CallFunction; -use util::get_sizet_dependent_function_name; #[must_use] pub fn load_irrt(ctx: &Context) -> Module { diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs deleted file mode 100644 index 823af277..00000000 --- a/nac3core/src/codegen/irrt/util.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate::codegen::{CodeGenContext, CodeGenerator}; - -// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". -// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". -#[must_use] -pub fn get_sizet_dependent_function_name( - generator: &mut G, - ctx: &CodeGenContext<'_, '_>, - name: &str, -) -> String { - let mut name = name.to_owned(); - match generator.get_size_type(ctx.ctx).get_bit_width() { - 32 => {} - 64 => name.push_str("64"), - bit_width => { - panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") - } - } - name -} - -pub mod function { - use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; - use inkwell::{ - types::{BasicMetadataTypeEnum, BasicType, FunctionType}, - values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue}, - }; - use itertools::Itertools; - - #[derive(Debug, Clone, Copy)] - struct Arg<'ctx> { - ty: BasicMetadataTypeEnum<'ctx>, - val: BasicMetadataValueEnum<'ctx>, - } - - /// Helper structure to reduce IRRT Inkwell function call boilerplate - pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> { - generator: &'d mut G, - ctx: &'b CodeGenContext<'ctx, 'a>, - /// Function name - name: &'c str, - /// Call arguments - args: Vec>, - } - - impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> { - pub fn begin( - generator: &'d mut G, - ctx: &'b CodeGenContext<'ctx, 'a>, - name: &'c str, - ) -> Self { - CallFunction { generator, ctx, name, args: Vec::new() } - } - - /// Push a call argument to the function call. - #[allow(clippy::needless_pass_by_value)] - #[must_use] - pub fn arg>(mut self, arg: Instance<'ctx, M>) -> Self { - let arg = Arg { - ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(), - val: arg.value.as_basic_value_enum().into(), - }; - self.args.push(arg); - self - } - - /// Call the function and expect the function to return a value of type of `return_model`. - #[must_use] - pub fn returning>(self, name: &str, return_model: M) -> Instance<'ctx, M> { - let ret_ty = return_model.get_type(self.generator, self.ctx.ctx); - - let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name); - let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work - let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work - ret - } - - /// Like [`CallFunction::returning_`] but `return_model` is automatically inferred. - #[must_use] - pub fn returning_auto + Default>(self, name: &str) -> Instance<'ctx, M> { - self.returning(name, M::default()) - } - - /// Call the function and expect the function to return a void-type. - pub fn returning_void(self) { - let ret_ty = self.ctx.ctx.void_type(); - - let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), ""); - } - - fn get_function(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx> - where - F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>, - { - // Get the LLVM function, declare the function if it doesn't exist - it will be defined by other - // components of NAC3. - let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| { - let tys = self.args.iter().map(|arg| arg.ty).collect_vec(); - let fn_type = make_fn_type(&tys); - self.ctx.module.add_function(self.name, fn_type, None) - }); - - let vals = self.args.iter().map(|arg| arg.val).collect_vec(); - self.ctx.builder.build_call(func, &vals, return_value_name).unwrap() - } - } -} diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 08250a64..6a85664b 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -21,6 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy { type Type: BasicType<'ctx>; /// Return the [`BasicType`] of this model. + #[must_use] fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; /// Check if a [`BasicType`] is the same type of this model. diff --git a/nac3core/src/codegen/model/float.rs b/nac3core/src/codegen/model/float.rs new file mode 100644 index 00000000..c2549149 --- /dev/null +++ b/nac3core/src/codegen/model/float.rs @@ -0,0 +1,88 @@ +use std::fmt; + +use inkwell::{context::Context, types::FloatType, values::FloatValue}; + +use crate::codegen::CodeGenerator; + +use super::*; + +pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy { + fn get_float_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Float32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Float64; + +impl<'ctx> FloatKind<'ctx> for Float32 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f32_type() + } +} + +impl<'ctx> FloatKind<'ctx> for Float64 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f64_type() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AnyFloat<'ctx>(FloatType<'ctx>); + +impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> { + fn get_float_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> FloatType<'ctx> { + self.0 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct FloatModel(pub N); +pub type Float<'ctx, N> = Instance<'ctx, FloatModel>; + +impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for FloatModel { + type Value = FloatValue<'ctx>; + type Type = FloatType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_float_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = FloatType::try_from(ty) else { + return Err(ModelError(format!("Expecting FloatType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_float_type(generator, ctx); + + // TODO: Inkwell does not have get_bit_width for FloatType? + // TODO: Quick hack for now, but does this actually work? + if ty != exp_ty { + return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}"))); + } + + Ok(()) + } +} diff --git a/nac3core/src/codegen/model/function.rs b/nac3core/src/codegen/model/function.rs new file mode 100644 index 00000000..7646c2bf --- /dev/null +++ b/nac3core/src/codegen/model/function.rs @@ -0,0 +1,125 @@ +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + types::{BasicMetadataTypeEnum, BasicType, FunctionType}, + values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue}, +}; +use itertools::Itertools; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". +// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". +#[must_use] +pub fn get_sizet_dependent_function_name( + generator: &mut G, + ctx: &CodeGenContext<'_, '_>, + name: &str, +) -> String { + let mut name = name.to_owned(); + match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => {} + 64 => name.push_str("64"), + bit_width => { + panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") + } + } + name +} + +#[derive(Debug, Clone, Copy)] +struct Arg<'ctx> { + ty: BasicMetadataTypeEnum<'ctx>, + val: BasicMetadataValueEnum<'ctx>, +} + +/// A structure to construct & call an LLVM function. +/// +/// This is a helper to reduce IRRT Inkwell function call boilerplate +// TODO: Remove the lifetimes somehow? There is 4 of them. +pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> { + generator: &'d mut G, + ctx: &'b CodeGenContext<'ctx, 'a>, + /// Function name + name: &'c str, + /// Call arguments + args: Vec>, + /// LLVM function Attributes + attrs: Vec<&'static str>, +} + +impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> { + pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self { + CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() } + } + + /// Push a list of LLVM function attributes to the function declaration. + #[must_use] + pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self { + self.attrs = attrs; + self + } + + /// Push a call argument to the function call. + #[allow(clippy::needless_pass_by_value)] + #[must_use] + pub fn arg>(mut self, arg: Instance<'ctx, M>) -> Self { + let arg = Arg { + ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(), + val: arg.value.as_basic_value_enum().into(), + }; + self.args.push(arg); + self + } + + /// Call the function and expect the function to return a value of type of `return_model`. + #[must_use] + pub fn returning>(self, name: &str, return_model: M) -> Instance<'ctx, M> { + let ret_ty = return_model.get_type(self.generator, self.ctx.ctx); + + let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name); + let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work + let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work + ret + } + + /// Like [`CallFunction::returning_`] but `return_model` is automatically inferred. + #[must_use] + pub fn returning_auto + Default>(self, name: &str) -> Instance<'ctx, M> { + self.returning(name, M::default()) + } + + /// Call the function and expect the function to return a void-type. + pub fn returning_void(self) { + let ret_ty = self.ctx.ctx.void_type(); + + let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), ""); + } + + fn get_function(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx> + where + F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>, + { + // Get the LLVM function. + let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| { + // Declare the function if it doesn't exist. + let tys = self.args.iter().map(|arg| arg.ty).collect_vec(); + + let func_type = make_fn_type(&tys); + let func = self.ctx.module.add_function(self.name, func_type, None); + + for attr in &self.attrs { + func.add_attribute( + AttributeLoc::Function, + self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + let vals = self.args.iter().map(|arg| arg.val).collect_vec(); + self.ctx.builder.build_call(func, &vals, return_value_name).unwrap() + } +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs index cf51f673..732d5d88 100644 --- a/nac3core/src/codegen/model/int.rs +++ b/nac3core/src/codegen/model/int.rs @@ -96,7 +96,6 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel { type Value = IntValue<'ctx>; type Type = IntType<'ctx>; - #[must_use] fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { self.0.get_int_type(generator, ctx) } diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs index b73a1de9..0c07224e 100644 --- a/nac3core/src/codegen/model/mod.rs +++ b/nac3core/src/codegen/model/mod.rs @@ -1,5 +1,7 @@ mod any; mod core; +mod float; +pub mod function; mod int; mod ptr; mod structure; @@ -7,6 +9,7 @@ pub mod util; pub use any::*; pub use core::*; +pub use float::*; pub use int::*; pub use ptr::*; pub use structure::*; diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs index d27b0126..02164d1e 100644 --- a/nac3core/src/codegen/model/ptr.rs +++ b/nac3core/src/codegen/model/ptr.rs @@ -89,15 +89,17 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, Element> { self.model.check_value(generator, ctx.ctx, new_ptr).unwrap() } - // Load the `i`-th element (0-based) on the array with [`inkwell::builder::Builder::build_in_bounds_gep`]. - pub fn ix( + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset. + #[must_use] + pub fn offset_const( &self, generator: &mut G, ctx: &CodeGenContext<'ctx, '_>, - i: IntValue<'ctx>, + offset: u64, name: &str, - ) -> Instance<'ctx, Element> { - self.offset(generator, ctx, i, name).load(generator, ctx, name) + ) -> Ptr<'ctx, Element> { + let offset = ctx.ctx.i32_type().const_int(offset, false); + self.offset(generator, ctx, offset, name) } /// Load the value with [`inkwell::builder::Builder::build_load`]. diff --git a/nac3core/src/codegen/numpy_new.rs b/nac3core/src/codegen/numpy_new.rs index b805cb4c..c0691574 100644 --- a/nac3core/src/codegen/numpy_new.rs +++ b/nac3core/src/codegen/numpy_new.rs @@ -94,8 +94,7 @@ where { let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape); - let ndarray = - NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray"); + let ndarray = NDArrayObject::alloca_ndarray_type(generator, ctx, ndarray_ty, "ndarray"); // Validate `shape` let ndims = ndarray.get_ndims(generator, ctx.ctx); @@ -321,7 +320,7 @@ pub fn gen_ndarray_arange<'ctx>( let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim"); // Allocate the resulting ndarray - let ndarray = NDArrayObject::alloca_uninitialized( + let ndarray = NDArrayObject::alloca( generator, ctx, ctx.primitives.float, @@ -385,20 +384,17 @@ pub fn gen_ndarray_shape<'ctx>( let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; - // Define models - let sizet_model = IntModel(SizeT); - // Process ndarray let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); let mut objects = Vec::with_capacity(ndarray.ndims as usize); for i in 0..ndarray.ndims { - let i = sizet_model.constant(generator, ctx.ctx, i); let dim = ndarray .instance .get(generator, ctx, |f| f.shape, "") - .ix(generator, ctx, i.value, "dim"); + .offset_const(generator, ctx, i, "") + .load(generator, ctx, "dim"); let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT objects @@ -427,20 +423,17 @@ pub fn gen_ndarray_strides<'ctx>( let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; - // Define models - let sizet_model = IntModel(SizeT); - // Process ndarray let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); let mut objects = Vec::with_capacity(ndarray.ndims as usize); for i in 0..ndarray.ndims { - let i = sizet_model.constant(generator, ctx.ctx, i); let dim = ndarray .instance .get(generator, ctx, |f| f.strides, "") - .ix(generator, ctx, i.value, "dim"); + .offset_const(generator, ctx, i, "") + .load(generator, ctx, "dim"); let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT objects @@ -524,7 +517,7 @@ pub fn gen_ndarray_array<'ctx>( // We simply make the output ndarray's ndims correct with `atleast_nd`. let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); - let output_ndims = extract_ndims(&mut ctx.unifier, ndims); + let output_ndims = extract_ndims(&ctx.unifier, ndims); let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8 let copy = copy.truncate(generator, ctx, Bool, "copy_bool"); diff --git a/nac3core/src/codegen/object/list.rs b/nac3core/src/codegen/object/list.rs index 2b41c7b8..05a49f46 100644 --- a/nac3core/src/codegen/object/list.rs +++ b/nac3core/src/codegen/object/list.rs @@ -1,5 +1,3 @@ -use inkwell::values::BasicValue; - use crate::{ codegen::{model::*, structure::List, CodeGenContext, CodeGenerator}, typecheck::typedef::{iter_type_vars, Type, TypeEnum}, diff --git a/nac3core/src/codegen/object/ndarray/array.rs b/nac3core/src/codegen/object/ndarray/array.rs index 720b9771..8aaa0324 100644 --- a/nac3core/src/codegen/object/ndarray/array.rs +++ b/nac3core/src/codegen/object/ndarray/array.rs @@ -41,13 +41,7 @@ impl<'ctx> NDArrayObject<'ctx> { let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape"); call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape); - let ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - dtype, - ndims_int, - "ndarray_from_list", - ); + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int, "ndarray_from_list"); ndarray.copy_shape_from_array(generator, ctx, shape); ndarray.create_data(generator, ctx); @@ -73,13 +67,8 @@ impl<'ctx> NDArrayObject<'ctx> { let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); if ndims == 1 { // `list` is not nested, does not need to copy. - let ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - dtype, - 1, - "ndarray_from_list_no_copy", - ); + let ndarray = + NDArrayObject::alloca(generator, ctx, dtype, 1, "ndarray_from_list_no_copy"); // Set data let data = list.get_opaque_items_ptr(generator, ctx); diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs index ff4d92a4..1b82e0b2 100644 --- a/nac3core/src/codegen/object/ndarray/broadcast.rs +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> { target_ndims: u64, target_shape: Ptr<'ctx, IntModel>, ) -> Self { - let broadcast_ndarray = NDArrayObject::alloca_uninitialized( + let broadcast_ndarray = NDArrayObject::alloca( generator, ctx, self.dtype, diff --git a/nac3core/src/codegen/object/ndarray/functions.rs b/nac3core/src/codegen/object/ndarray/functions.rs index 5fa0e530..69d35b21 100644 --- a/nac3core/src/codegen/object/ndarray/functions.rs +++ b/nac3core/src/codegen/object/ndarray/functions.rs @@ -80,10 +80,10 @@ where let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { // Special handling for floats - let n = scalar.instance.into_float_value(); + let n = scalar.value.into_float_value(); handle_float(generator, ctx, n) } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { - let n = scalar.instance.into_int_value(); + let n = scalar.value.into_int_value(); if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() { ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap() @@ -95,7 +95,7 @@ where }; assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check - ScalarObject { instance: result.into(), dtype: ret_int_dtype } + ScalarObject { value: result.into(), dtype: ret_int_dtype } } impl<'ctx> ScalarObject<'ctx> { @@ -104,7 +104,7 @@ impl<'ctx> ScalarObject<'ctx> { /// Panic if the type is wrong. pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - self.instance.into_float_value() // self.value must be a FloatValue + self.value.into_float_value() // self.value must be a FloatValue } else { panic!("not a float type") } @@ -115,7 +115,7 @@ impl<'ctx> ScalarObject<'ctx> { /// Panic if the type is wrong. pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) { - let value = self.instance.into_int_value(); + let value = self.value.into_int_value(); debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check value } else { @@ -142,12 +142,12 @@ impl<'ctx> ScalarObject<'ctx> { let common_ty = lhs.dtype; let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) { - let lhs = lhs.instance.into_float_value(); - let rhs = rhs.instance.into_float_value(); + let lhs = lhs.value.into_float_value(); + let rhs = rhs.value.into_float_value(); ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap() } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { - let lhs = lhs.instance.into_int_value(); - let rhs = rhs.instance.into_int_value(); + let lhs = lhs.value.into_int_value(); + let rhs = rhs.value.into_int_value(); ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() } else { unsupported_type(ctx, [lhs.dtype, rhs.dtype]); @@ -266,14 +266,14 @@ impl<'ctx> ScalarObject<'ctx> { pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { // TODO: Why is the original code being so lax about i1 and i8 for the returned int type? let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { - self.instance.into_int_value() + self.value.into_int_value() } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { - let n = self.instance.into_int_value(); + let n = self.value.into_int_value(); ctx.builder .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") .unwrap() } else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); ctx.builder .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") .unwrap() @@ -281,7 +281,7 @@ impl<'ctx> ScalarObject<'ctx> { unsupported_type(ctx, [self.dtype]) }; - ScalarObject { dtype: ctx.primitives.bool, instance: result.as_basic_value_enum() } + ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() } } /// Invoke NAC3's builtin `float()`. @@ -290,21 +290,21 @@ impl<'ctx> ScalarObject<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - self.instance.into_float_value() + self.value.into_float_value() } else if ctx .unifier .unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat()) { - let n = self.instance.into_int_value(); + let n = self.value.into_int_value(); ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() } else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) { - let n = self.instance.into_int_value(); + let n = self.value.into_int_value(); ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap() } else { unsupported_type(ctx, [self.dtype]); }; - ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float } + ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } } /// Invoke NAC3's builtin `round()`. @@ -318,13 +318,13 @@ impl<'ctx> ScalarObject<'ctx> { let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); let n = llvm_intrinsics::call_float_round(ctx, n, None); ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap() } else { unsupported_type(ctx, [self.dtype, ret_int_dtype]) }; - ScalarObject { dtype: ret_int_dtype, instance: result.as_basic_value_enum() } + ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() } } /// Invoke NAC3's builtin `np_round()`. @@ -333,12 +333,12 @@ impl<'ctx> ScalarObject<'ctx> { #[must_use] pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); llvm_intrinsics::call_float_rint(ctx, n, None) } else { unsupported_type(ctx, [self.dtype]) }; - ScalarObject { dtype: ctx.primitives.float, instance: result.as_basic_value_enum() } + ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() } } /// Invoke NAC3's builtin `min()` or `max()`. @@ -360,8 +360,8 @@ impl<'ctx> ScalarObject<'ctx> { MinOrMax::Max => llvm_intrinsics::call_float_maxnum, }; let result = - function(ctx, a.instance.into_float_value(), b.instance.into_float_value(), None); - ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float } + function(ctx, a.value.into_float_value(), b.value.into_float_value(), None); + ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } } else if ctx.unifier.unioned_any( common_dtype, [unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(), @@ -371,9 +371,8 @@ impl<'ctx> ScalarObject<'ctx> { MinOrMax::Min => llvm_intrinsics::call_int_umin, MinOrMax::Max => llvm_intrinsics::call_int_umax, }; - let result = - function(ctx, a.instance.into_int_value(), b.instance.into_int_value(), None); - ScalarObject { instance: result.as_basic_value_enum(), dtype: common_dtype } + let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None); + ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype } } else { unsupported_type(ctx, [common_dtype]) } @@ -399,11 +398,11 @@ impl<'ctx> ScalarObject<'ctx> { FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, }; - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); let n = function(ctx, n, None); let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap(); - ScalarObject { dtype: ret_int_dtype, instance: n.as_basic_value_enum() } + ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() } } else { unsupported_type(ctx, [self.dtype]) } @@ -419,9 +418,9 @@ impl<'ctx> ScalarObject<'ctx> { FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, }; - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); let n = function(ctx, n, None); - ScalarObject { dtype: ctx.primitives.float, instance: n.as_basic_value_enum() } + ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() } } else { unsupported_type(ctx, [self.dtype]) } @@ -431,16 +430,16 @@ impl<'ctx> ScalarObject<'ctx> { #[must_use] pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.instance.into_float_value(); + let n = self.value.into_float_value(); let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); - ScalarObject { instance: n.into(), dtype: ctx.primitives.float } + ScalarObject { value: n.into(), dtype: ctx.primitives.float } } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { - let n = self.instance.into_int_value(); + let n = self.value.into_int_value(); let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); - ScalarObject { instance: n.into(), dtype: self.dtype } + ScalarObject { value: n.into(), dtype: self.dtype } } else { unsupported_type(ctx, [self.dtype]) } @@ -482,7 +481,7 @@ impl<'ctx> NDArrayObject<'ctx> { pextremum_index.store(ctx, zero); let first_scalar = self.get_nth(generator, ctx, zero); - ctx.builder.build_store(pextremum, first_scalar.instance).unwrap(); + ctx.builder.build_store(pextremum, first_scalar.value).unwrap(); // Find extremum let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1 @@ -495,7 +494,7 @@ impl<'ctx> NDArrayObject<'ctx> { let scalar = self.get_nth(generator, ctx, i); let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); - let old_extremum = ScalarObject { dtype: self.dtype, instance: old_extremum }; + let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); @@ -523,7 +522,7 @@ impl<'ctx> NDArrayObject<'ctx> { let extremum_index = pextremum_index.load(generator, ctx, "extremum_index"); let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap(); - let extremum = ScalarObject { dtype: self.dtype, instance: extremum }; + let extremum = ScalarObject { dtype: self.dtype, value: extremum }; (extremum, extremum_index) } diff --git a/nac3core/src/codegen/object/ndarray/indexing.rs b/nac3core/src/codegen/object/ndarray/indexing.rs index 36e230f1..2ccc2567 100644 --- a/nac3core/src/codegen/object/ndarray/indexing.rs +++ b/nac3core/src/codegen/object/ndarray/indexing.rs @@ -1,6 +1,6 @@ use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator}; -use super::{scalar::ScalarOrNDArray, NDArrayObject}; +use super::NDArrayObject; pub type NDIndexType = Byte; @@ -215,8 +215,7 @@ impl<'ctx> NDArrayObject<'ctx> { name: &str, ) -> Self { let dst_ndims = self.deduce_ndims_after_indexing_with(indexes); - let dst_ndarray = - NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, dst_ndims, name); + let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims, name); let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes); call_nac3_ndarray_index( diff --git a/nac3core/src/codegen/object/ndarray/mapping.rs b/nac3core/src/codegen/object/ndarray/mapping.rs index f2633f32..0f50e45f 100644 --- a/nac3core/src/codegen/object/ndarray/mapping.rs +++ b/nac3core/src/codegen/object/ndarray/mapping.rs @@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> { let out_ndarray = match out { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. - let result_ndarray = NDArrayObject::alloca_uninitialized( + let result_ndarray = NDArrayObject::alloca( generator, ctx, dtype, @@ -137,7 +137,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { if let Some(scalars) = all_scalars { let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index let scalar = - ScalarObject { instance: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype }; + ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype }; Ok(ScalarOrNDArray::Scalar(scalar)) } else { // Promote all input to ndarrays and map through them. diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index b5515a6f..01763869 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -22,7 +22,10 @@ use crate::{ structure::{NDArray, SimpleNDArray}, CodeGenContext, CodeGenerator, }, - toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, + toplevel::{ + helper::{create_ndims, extract_ndims}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + }, typecheck::typedef::Type, }; use indexing::RustNDIndex; @@ -71,6 +74,12 @@ impl<'ctx> NDArrayObject<'ctx> { NDArrayObject { dtype, ndims, instance: value } } + /// Forget that this is an ndarray and convert to an [`AnyObject`]. + pub fn to_any_object(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> { + let ty = self.get_ndarray_type(ctx); + AnyObject { value: self.instance.value.as_basic_value_enum(), ty } + } + /// Create a [`SimpleNDArray`] from the contents of this ndarray. /// /// This function may or may not be expensive depending on if this ndarray has contiguous data. @@ -88,6 +97,7 @@ impl<'ctx> NDArrayObject<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, item_model: Item, + name: &str, ) -> Ptr<'ctx, StructModel>> { // Sanity check on `self.dtype` and `item_model`. let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); @@ -101,7 +111,7 @@ impl<'ctx> NDArrayObject<'ctx> { let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); // Allocate and setup the resulting [`SimpleNDArray`]. - let result = simple_ndarray_model.alloca(generator, ctx, "simple_ndarray"); + let result = simple_ndarray_model.alloca(generator, ctx, name); // Set ndims and shape. let ndims = self.get_ndims(generator, ctx.ctx); @@ -155,13 +165,7 @@ impl<'ctx> NDArrayObject<'ctx> { // TODO: Check if `ndims` is consistent with that in `simple_array`? // Allocate the resulting ndarray. - let ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - dtype, - ndims, - "from_simple_ndarray", - ); + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "from_simple_ndarray"); // Set data, shape by simply copying addresses. let data = simple_ndarray @@ -178,6 +182,12 @@ impl<'ctx> NDArrayObject<'ctx> { ndarray } + /// Get the typechecker ndarray type of this [`NDArrayObject`]. + pub fn get_ndarray_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type { + let ndims = create_ndims(&mut ctx.unifier, self.ndims); + make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims)) + } + /// Get the `np.size()` of this ndarray. pub fn size( &self, @@ -243,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> { ) -> ScalarObject<'ctx> { let p = self.get_nth_pointer(generator, ctx, nth, "value"); let value = ctx.builder.build_load(p, "value").unwrap(); - ScalarObject { dtype: self.dtype, instance: value } + ScalarObject { dtype: self.dtype, value } } /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. @@ -283,7 +293,7 @@ impl<'ctx> NDArrayObject<'ctx> { /// - `ndims`: set to the value of `ndims`. /// - `shape`: allocated with an array of length `ndims` with uninitialized values. /// - `strides`: allocated with an array of length `ndims` with uninitialized values. - pub fn alloca_uninitialized( + pub fn alloca( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, @@ -318,7 +328,7 @@ impl<'ctx> NDArrayObject<'ctx> { /// Convenience function. /// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray. - pub fn alloca_uninitialized_of_type( + pub fn alloca_ndarray_type( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ndarray_ty: Type, @@ -326,11 +336,34 @@ impl<'ctx> NDArrayObject<'ctx> { ) -> Self { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); let ndims = extract_ndims(&ctx.unifier, ndims); - Self::alloca_uninitialized(generator, ctx, dtype, ndims, name) + Self::alloca(generator, ctx, dtype, ndims, name) } - /// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents - /// over. + /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_constant_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[u64], + name: &str, + ) -> Self { + let sizet_model = IntModel(SizeT); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape"); + for (i, dim) in shape.iter().enumerate() { + let dim = sizet_model.constant(generator, ctx.ctx, *dim); + dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, dim); + } + + ndarray + } + + /// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over. /// /// The new ndarray will own its data and will be C-contiguous. #[must_use] @@ -340,8 +373,7 @@ impl<'ctx> NDArrayObject<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, name: &str, ) -> Self { - let clone = - NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name); + let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, name); let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape"); clone.copy_shape_from_array(generator, ctx, shape); @@ -519,7 +551,7 @@ impl<'ctx> NDArrayObject<'ctx> { { self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| { let value = ctx.builder.build_load(p, "value").unwrap(); - let scalar = ScalarObject { dtype: self.dtype, instance: value }; + let scalar = ScalarObject { dtype: self.dtype, value }; body(generator, ctx, hooks, i, scalar) }) } @@ -588,13 +620,8 @@ impl<'ctx> NDArrayObject<'ctx> { let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); - let dst_ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - self.dtype, - new_ndims, - "reshaped_ndarray", - ); + let dst_ndarray = + NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray"); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape); let size = self.size(generator, ctx); @@ -661,13 +688,8 @@ impl<'ctx> NDArrayObject<'ctx> { // Define models let sizet_model = IntModel(SizeT); - let transposed_ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - self.dtype, - self.ndims, - "transposed_ndarray", - ); + let transposed_ndarray = + NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, "transposed_ndarray"); let num_axes = self.get_ndims(generator, ctx.ctx); @@ -686,7 +708,7 @@ impl<'ctx> NDArrayObject<'ctx> { transposed_ndarray } - /// Check if this NDArray can be used as an `out` ndarray for an operation. + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. /// /// Raise an exception if the shapes do not match. pub fn check_can_be_written_by_out( diff --git a/nac3core/src/codegen/object/ndarray/nalgebra.rs b/nac3core/src/codegen/object/ndarray/nalgebra.rs index 8b137891..0eef4707 100644 --- a/nac3core/src/codegen/object/ndarray/nalgebra.rs +++ b/nac3core/src/codegen/object/ndarray/nalgebra.rs @@ -1 +1,53 @@ +use inkwell::values::{BasicValue, BasicValueEnum}; +use crate::codegen::{model::*, structure::SimpleNDArray, CodeGenContext, CodeGenerator}; + +use super::NDArrayObject; + +pub fn perform_nalgebra_call<'ctx, 'a, const NUM_INPUTS: usize, const NUM_OUTPUTS: usize, G, F>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: [NDArrayObject<'ctx>; NUM_INPUTS], + output_ndims: [u64; NUM_OUTPUTS], + invoke_function: F, +) -> [NDArrayObject<'ctx>; NUM_OUTPUTS] +where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut CodeGenContext<'ctx, 'a>, + [BasicValueEnum<'ctx>; NUM_INPUTS], + [BasicValueEnum<'ctx>; NUM_OUTPUTS], + ), +{ + // TODO: Allow stacked inputs. See NumPy docs. + + let f64_model = FloatModel(Float64); + let simple_ndarray_model = StructModel(SimpleNDArray { item: f64_model }); + + // Prepare inputs & outputs and invoke + let inputs = inputs.map(|input| { + // Sanity check. Typechecker ensures this. + assert!(ctx.unifier.unioned(input.dtype, ctx.primitives.float)); + + input + .make_simple_ndarray(generator, ctx, FloatModel(Float64), "nalgebra_input") + .value + .as_basic_value_enum() + }); + let outputs = [simple_ndarray_model.alloca(generator, ctx, "nalgebra_output"); NUM_OUTPUTS]; + invoke_function(ctx, inputs, outputs.map(|output| output.value.as_basic_value_enum())); + + // Turn the outputs into strided NDArrays + let mut output_i = 0; + outputs.map(|output| { + let out = NDArrayObject::from_simple_ndarray( + generator, + ctx, + output, + ctx.primitives.float, + output_ndims[output_i], + ); + output_i += 1; + out + }) +} diff --git a/nac3core/src/codegen/object/ndarray/product.rs b/nac3core/src/codegen/object/ndarray/product.rs index 3a39d3a6..2b213781 100644 --- a/nac3core/src/codegen/object/ndarray/product.rs +++ b/nac3core/src/codegen/object/ndarray/product.rs @@ -55,7 +55,7 @@ impl<'ctx> NDArrayObject<'ctx> { let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape); let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape); - let dst = NDArrayObject::alloca_uninitialized( + let dst = NDArrayObject::alloca( generator, ctx, ctx.primitives.float, diff --git a/nac3core/src/codegen/object/ndarray/scalar.rs b/nac3core/src/codegen/object/ndarray/scalar.rs index 36b43ece..0625bc6a 100644 --- a/nac3core/src/codegen/object/ndarray/scalar.rs +++ b/nac3core/src/codegen/object/ndarray/scalar.rs @@ -15,7 +15,7 @@ use super::NDArrayObject; #[derive(Debug, Clone, Copy)] pub struct ScalarObject<'ctx> { pub dtype: Type, - pub instance: BasicValueEnum<'ctx>, + pub value: BasicValueEnum<'ctx>, } impl<'ctx> ScalarObject<'ctx> { @@ -31,12 +31,11 @@ impl<'ctx> ScalarObject<'ctx> { let pbyte_model = PtrModel(IntModel(Byte)); // We have to put the value on the stack to get a data pointer. - let data = ctx.builder.build_alloca(self.instance.get_type(), "as_ndarray_scalar").unwrap(); - ctx.builder.build_store(data, self.instance).unwrap(); + let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap(); + ctx.builder.build_store(data, self.value).unwrap(); let data = pbyte_model.pointer_cast(generator, ctx, data, "data"); - let ndarray = - NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray"); + let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray"); ndarray.instance.set(ctx, |f| f.data, data); ndarray } @@ -54,7 +53,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { match self { - ScalarOrNDArray::Scalar(scalar) => scalar.instance, + ScalarOrNDArray::Scalar(scalar) => scalar.value, ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), } } @@ -135,7 +134,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>( ScalarOrNDArray::NDArray(ndarray) } _ => { - let scalar = ScalarObject { dtype: object.ty, instance: object.value }; + let scalar = ScalarObject { dtype: object.ty, value: object.value }; ScalarOrNDArray::Scalar(scalar) } } diff --git a/nac3core/src/codegen/object/ndarray/shape_util.rs b/nac3core/src/codegen/object/ndarray/shape_util.rs index c64ca935..dec32382 100644 --- a/nac3core/src/codegen/object/ndarray/shape_util.rs +++ b/nac3core/src/codegen/object/ndarray/shape_util.rs @@ -48,8 +48,9 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( // Load the i-th int32 in the input sequence let int = input_sequence .instance - .get(generator, ctx, |f| f.items, "int") - .ix(generator, ctx, i.value, "int") + .get(generator, ctx, |f| f.items, "") + .offset(generator, ctx, i.value, "") + .load(generator, ctx, "") .value .into_int_value(); @@ -65,7 +66,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( (len, result) } - TypeEnum::TTuple { ty: tuple_types, .. } => { + TypeEnum::TTuple { .. } => { // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` let input_sequence = TupleObject::from_object(ctx, input_sequence); diff --git a/nac3core/src/codegen/object/tuple.rs b/nac3core/src/codegen/object/tuple.rs index 6f84faf1..a1143b7d 100644 --- a/nac3core/src/codegen/object/tuple.rs +++ b/nac3core/src/codegen/object/tuple.rs @@ -34,13 +34,13 @@ impl<'ctx> TupleObject<'ctx> { }; let value = object.value.into_struct_value(); - if value.get_type().count_fields() as usize != tys.len() { - panic!( - "Tuple type has {} item(s), but the LLVM struct value has {} field(s)", - tys.len(), - value.get_type().count_fields() - ); - } + let value_num_fields = value.get_type().count_fields() as usize; + assert!( + value_num_fields != tys.len(), + "Tuple type has {} item(s), but the LLVM struct value has {} field(s)", + tys.len(), + value_num_fields + ); TupleObject { tys: tys.clone(), value } } @@ -74,18 +74,23 @@ impl<'ctx> TupleObject<'ctx> { /// Get the `len()` of this tuple. /// /// We statically know the lengths of tuples in NAC3. + #[must_use] pub fn len(&self) -> usize { self.tys.len() } + /// Check if this tuple is an empty/unit tuple. + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Get the `i`-th (0-based) object in this tuple. pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> { - if i >= self.len() { - panic!("Tuple object with length {} have index {i}", self.len()); - } + assert!(i >= self.len(), "Tuple object with length {} have index {i}", self.len()); let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap(); let ty = self.tys[i]; - AnyObject { value, ty } + AnyObject { ty, value } } } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 68dd5fd1..25e901aa 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1827,7 +1827,7 @@ pub fn gen_stmt( let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?; cslice_model.check_value(generator, ctx.ctx, msg).unwrap() } - None => ctx.gen_string(generator, "").into(), + None => ctx.gen_string(generator, ""), }; ctx.make_assert_impl( diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7c656701..35144323 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -15,14 +15,19 @@ use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, - extern_fns, irrt, llvm_intrinsics, + extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power}, + irrt, llvm_intrinsics, + model::{IntModel, SizeT}, numpy::*, numpy_new::{self, gen_ndarray_transpose}, object::{ ndarray::{ functions::{FloorOrCeil, MinOrMax}, + nalgebra::perform_nalgebra_call, scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray}, + NDArrayObject, }, + tuple::TupleObject, AnyObject, }, stmt::exn_constructor, @@ -1109,7 +1114,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunBool => scalar.cast_to_bool(ctx), _ => unreachable!(), }; - Ok(result.instance) + Ok(result.value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1171,7 +1176,7 @@ impl<'a> BuiltinBuilder<'a> { ctx, ret_int_dtype, |generator, ctx, _i, scalar| { - Ok(scalar.round(generator, ctx, ret_int_dtype).instance) + Ok(scalar.round(generator, ctx, ret_int_dtype).value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1237,7 +1242,7 @@ impl<'a> BuiltinBuilder<'a> { ctx, int_sized, |generator, ctx, _i, scalar| { - Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).instance) + Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1638,7 +1643,7 @@ impl<'a> BuiltinBuilder<'a> { ctx.primitives.float, move |_generator, ctx, _i, scalar| { let result = scalar.np_floor_or_ceil(ctx, kind); - Ok(result.instance) + Ok(result.value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1667,7 +1672,7 @@ impl<'a> BuiltinBuilder<'a> { ctx.primitives.float, |_generator, ctx, _i, scalar| { let result = scalar.np_round(ctx); - Ok(result.instance) + Ok(result.value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1754,10 +1759,10 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - let m = ScalarObject { dtype: m_ty, instance: m_val }; - let n = ScalarObject { dtype: n_ty, instance: n_val }; + let m = ScalarObject { dtype: m_ty, value: m_val }; + let n = ScalarObject { dtype: n_ty, value: n_val }; let result = ScalarObject::min_or_max(ctx, kind, m, n); - Ok(Some(result.instance)) + Ok(Some(result.value)) }, )))), loc: None, @@ -1811,10 +1816,10 @@ impl<'a> BuiltinBuilder<'a> { .value .as_basic_value_enum(), PrimDef::FunNpMin => { - a.min_or_max(generator, ctx, MinOrMax::Min).instance.as_basic_value_enum() + a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum() } PrimDef::FunNpMax => { - a.min_or_max(generator, ctx, MinOrMax::Max).instance.as_basic_value_enum() + a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum() } _ => unreachable!(), }; @@ -1883,7 +1888,7 @@ impl<'a> BuiltinBuilder<'a> { let x2 = scalars[1]; let result = ScalarObject::min_or_max(ctx, kind, x1, x2); - Ok(result.instance) + Ok(result.value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1925,7 +1930,7 @@ impl<'a> BuiltinBuilder<'a> { generator, ctx, num_ty.ty, - |_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).instance), + |_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value), )?; Ok(Some(result.to_basic_value_enum())) }, @@ -2253,6 +2258,7 @@ impl<'a> BuiltinBuilder<'a> { ), PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { + // Function type: NDArray[float; 2] -> NDArray[float; 2] create_fn_by_codegen( self.unifier, &VarMap::new(), @@ -2263,14 +2269,22 @@ impl<'a> BuiltinBuilder<'a> { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x1 = AnyObject { value: x1_val, ty: x1_ty }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); let func = match prim { - PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky, - PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv, - PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv, + PrimDef::FunNpLinalgCholesky => extern_fns::call_np_linalg_cholesky, + PrimDef::FunNpLinalgInv => extern_fns::call_np_linalg_inv, + PrimDef::FunNpLinalgPinv => extern_fns::call_np_linalg_pinv, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + + let [out] = + perform_nalgebra_call(generator, ctx, [x1], [2], |ctx, [x1], [out]| { + func(ctx, x1, out, Some(prim.name())); + }); + + Ok(Some(out.instance.value.as_basic_value_enum())) }), ) } @@ -2279,6 +2293,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur | PrimDef::FunSpLinalgHessenberg => { + // Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 2]) let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { ty: vec![self.ndarray_float_2d, self.ndarray_float_2d], is_vararg_ctx: false, @@ -2293,22 +2308,35 @@ impl<'a> BuiltinBuilder<'a> { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x1 = AnyObject { value: x1_val, ty: x1_ty }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); let func = match prim { - PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr, - PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu, - PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur, - PrimDef::FunSpLinalgHessenberg => { - builtin_fns::call_sp_linalg_hessenberg - } + PrimDef::FunNpLinalgQr => extern_fns::call_np_linalg_qr, + PrimDef::FunSpLinalgLu => extern_fns::call_sp_linalg_lu, + PrimDef::FunSpLinalgSchur => extern_fns::call_sp_linalg_schur, + PrimDef::FunSpLinalgHessenberg => extern_fns::call_sp_linalg_hessenberg, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + + let out = perform_nalgebra_call( + generator, + ctx, + [x1], + [2, 2], + |ctx, [x1], [out1, out2]| func(ctx, x1, out1, out2, Some(prim.name())), + ); + + // Create the output tuple + let out = out.map(|o| o.to_any_object(ctx)); + let out = TupleObject::create(generator, ctx, out, prim.name()); + Ok(Some(out.value.as_basic_value_enum())) }), ) } PrimDef::FunNpLinalgSvd => { + // Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 1], NDArray[float; 2]) let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d], is_vararg_ctx: false, @@ -2323,8 +2351,30 @@ impl<'a> BuiltinBuilder<'a> { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x1 = AnyObject { ty: x1_ty, value: x1_val }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?)) + let out = perform_nalgebra_call( + generator, + ctx, + [x1], + [2, 1, 2], + |ctx, [x1], [out1, out2, out3]| { + extern_fns::call_np_linalg_svd( + ctx, + x1, + out1, + out2, + out3, + Some(prim.name()), + ); + }, + ); + + // Create the output tuple + let out = out.map(|o| o.to_any_object(ctx)); + let out = TupleObject::create(generator, ctx, out, prim.name()); + Ok(Some(out.value.as_basic_value_enum())) }), ) } @@ -2337,15 +2387,26 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x1 = AnyObject { ty: x1_ty, value: x1_val }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + + // The second argument is converted to an ndarray for implementation convenience. + // TODO: Don't do that. let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + let x2 = ScalarObject { dtype: x2_ty, value: x2_val }; + let x2 = x2.as_ndarray(generator, ctx); - Ok(Some(builtin_fns::call_np_linalg_matrix_power( + let [out] = perform_nalgebra_call( generator, ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) + [x1, x2], + [2], + |ctx, [x1, x2], [out]| { + call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name())); + }, + ); + Ok(Some(out.instance.value.as_basic_value_enum())) }), ), PrimDef::FunNpLinalgDet => create_fn_by_codegen( @@ -2357,7 +2418,22 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?)) + let x1 = AnyObject { value: x1_val, ty: x1_ty }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + + // The output is returned as a 1D ndarray, even though the result is a single float. + // It is implemented like this at the moment because it is convenient. + // TODO: Don't do that. + let [out] = + perform_nalgebra_call(generator, ctx, [x1], [1], |ctx, [x1], [out]| { + call_np_linalg_det(ctx, x1, out, Some(prim.name())); + }); + + let sizet_model = IntModel(SizeT); + let zero = sizet_model.const_0(generator, ctx.ctx); + let determinant = out.get_nth(generator, ctx, zero); + + Ok(Some(determinant.value)) }), ), _ => unreachable!(), diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 2b4ea43b..63f6173d 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,12 +1,11 @@ use crate::{ - symbol_resolver::SymbolValue, toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, }, }; -use itertools::{Either, Itertools}; +use itertools::Itertools; /// Creates a `ndarray` [`Type`] with the given type arguments. ///