forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: done

This commit is contained in:
lyken 2024-08-14 15:56:59 +08:00
parent 2fbe981701
commit fd78f7a0e8
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
26 changed files with 527 additions and 796 deletions

View File

@ -1,17 +1,12 @@
use inkwell::types::BasicTypeEnum; use inkwell::values::{BasicValueEnum, IntValue};
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
use inkwell::IntPredicate; use inkwell::IntPredicate;
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{ use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, RangeValue, TypedArrayLikeAccessor};
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor,
};
use crate::codegen::expr::destructure_range; use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range; use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Type, TypeEnum}; use crate::typecheck::typedef::{Type, TypeEnum};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// Shorthand for [`unreachable!()`] when a type of argument is not supported.
@ -84,505 +79,3 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
} }
}) })
} }
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
fn build_output_struct<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: Vec<BasicValueEnum<'ctx>>,
) -> PointerValue<'ctx> {
let field_ty =
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
out_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
out_ptr
}
/// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_matrix_power` linalg function
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
todo!();
/*
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
*/
}
/// Invokes the `np_linalg_det` linalg function
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")
.map(Into::into)
.unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}

View File

@ -32,10 +32,7 @@ use crate::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{ values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
StructValue,
},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
@ -314,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.raise_exn( self.raise_exn(
generator, generator,
"0:NotImplementedError", "0:NotImplementedError",
msg.into(), msg,
[None, None, None], [None, None, None],
self.current_loc, self.current_loc,
); );
@ -639,7 +636,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap())); params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap()));
let err_msg = self.gen_string(generator, err_msg); let err_msg = self.gen_string(generator, err_msg);
self.make_assert_impl(generator, cond, err_name, err_msg.into(), params, loc); self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
} }
pub fn make_assert_impl<G: CodeGenerator + ?Sized>( pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
@ -1574,9 +1571,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
gen_binop_expr_with_values( gen_binop_expr_with_values(
generator, generator,
ctx, ctx,
(&Some(left.dtype), left.instance), (&Some(left.dtype), left.value),
op, op,
(&Some(right.dtype), right.instance), (&Some(right.dtype), right.value),
ctx.current_loc, ctx.current_loc,
)? )?
.unwrap() .unwrap()
@ -2689,7 +2686,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ctx.raise_exn( ctx.raise_exn(
generator, generator,
"0:UnwrapNoneError", "0:UnwrapNoneError",
err_msg.into(), err_msg,
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );

View File

@ -2,7 +2,6 @@ use crate::symbol_resolver::SymbolResolver;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
mod test; mod test;
pub mod util;
use super::model::*; use super::model::*;
use super::object::ndarray::broadcast::ShapeEntry; use super::object::ndarray::broadcast::ShapeEntry;
@ -17,6 +16,7 @@ use super::{
}; };
use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use function::{get_sizet_dependent_function_name, CallFunction};
use inkwell::values::BasicValue; use inkwell::values::BasicValue;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -29,8 +29,6 @@ use inkwell::{
}; };
use itertools::Either; use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use util::function::CallFunction;
use util::get_sizet_dependent_function_name;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt(ctx: &Context) -> Module {

View File

@ -1,107 +0,0 @@
use crate::codegen::{CodeGenContext, CodeGenerator};
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
pub mod function {
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
use inkwell::{
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// Helper structure to reduce IRRT Inkwell function call boilerplate
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
}
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn begin(
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
name: &'c str,
) -> Self {
CallFunction { generator, ctx, name, args: Vec::new() }
}
/// Push a call argument to the function call.
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Call the function and expect the function to return a value of type of `return_model`.
#[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
#[must_use]
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
self.returning(name, M::default())
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
}
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function, declare the function if it doesn't exist - it will be defined by other
// components of NAC3.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let fn_type = make_fn_type(&tys);
self.ctx.module.add_function(self.name, fn_type, None)
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}
}

View File

@ -21,6 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
type Type: BasicType<'ctx>; type Type: BasicType<'ctx>;
/// Return the [`BasicType`] of this model. /// Return the [`BasicType`] of this model.
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
/// Check if a [`BasicType`] is the same type of this model. /// Check if a [`BasicType`] is the same type of this model.

View File

@ -0,0 +1,88 @@
use std::fmt;
use inkwell::{context::Context, types::FloatType, values::FloatValue};
use crate::codegen::CodeGenerator;
use super::*;
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Float32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Float64;
impl<'ctx> FloatKind<'ctx> for Float32 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f32_type()
}
}
impl<'ctx> FloatKind<'ctx> for Float64 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f64_type()
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> FloatType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct FloatModel<N>(pub N);
pub type Float<'ctx, N> = Instance<'ctx, FloatModel<N>>;
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for FloatModel<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_float_type(generator, ctx)
}
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = FloatType::try_from(ty) else {
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
};
let exp_ty = self.0.get_float_type(generator, ctx);
// TODO: Inkwell does not have get_bit_width for FloatType?
// TODO: Quick hack for now, but does this actually work?
if ty != exp_ty {
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
}
Ok(())
}
}

View File

@ -0,0 +1,125 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// A structure to construct & call an LLVM function.
///
/// This is a helper to reduce IRRT Inkwell function call boilerplate
// TODO: Remove the lifetimes somehow? There is 4 of them.
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
/// LLVM function Attributes
attrs: Vec<&'static str>,
}
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
}
/// Push a list of LLVM function attributes to the function declaration.
#[must_use]
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
self.attrs = attrs;
self
}
/// Push a call argument to the function call.
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Call the function and expect the function to return a value of type of `return_model`.
#[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
#[must_use]
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
self.returning(name, M::default())
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
}
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
// Declare the function if it doesn't exist.
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let func_type = make_fn_type(&tys);
let func = self.ctx.module.add_function(self.name, func_type, None);
for attr in &self.attrs {
func.add_attribute(
AttributeLoc::Function,
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}

View File

@ -96,7 +96,6 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
type Value = IntValue<'ctx>; type Value = IntValue<'ctx>;
type Type = IntType<'ctx>; type Type = IntType<'ctx>;
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_int_type(generator, ctx) self.0.get_int_type(generator, ctx)
} }

View File

@ -1,5 +1,7 @@
mod any; mod any;
mod core; mod core;
mod float;
pub mod function;
mod int; mod int;
mod ptr; mod ptr;
mod structure; mod structure;
@ -7,6 +9,7 @@ pub mod util;
pub use any::*; pub use any::*;
pub use core::*; pub use core::*;
pub use float::*;
pub use int::*; pub use int::*;
pub use ptr::*; pub use ptr::*;
pub use structure::*; pub use structure::*;

View File

@ -89,15 +89,17 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, Element> {
self.model.check_value(generator, ctx.ctx, new_ptr).unwrap() self.model.check_value(generator, ctx.ctx, new_ptr).unwrap()
} }
// Load the `i`-th element (0-based) on the array with [`inkwell::builder::Builder::build_in_bounds_gep`]. /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
pub fn ix<G: CodeGenerator + ?Sized>( #[must_use]
pub fn offset_const<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
i: IntValue<'ctx>, offset: u64,
name: &str, name: &str,
) -> Instance<'ctx, Element> { ) -> Ptr<'ctx, Element> {
self.offset(generator, ctx, i, name).load(generator, ctx, name) let offset = ctx.ctx.i32_type().const_int(offset, false);
self.offset(generator, ctx, offset, name)
} }
/// Load the value with [`inkwell::builder::Builder::build_load`]. /// Load the value with [`inkwell::builder::Builder::build_load`].

View File

@ -94,8 +94,7 @@ where
{ {
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape); let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
let ndarray = let ndarray = NDArrayObject::alloca_ndarray_type(generator, ctx, ndarray_ty, "ndarray");
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
// Validate `shape` // Validate `shape`
let ndims = ndarray.get_ndims(generator, ctx.ctx); let ndims = ndarray.get_ndims(generator, ctx.ctx);
@ -321,7 +320,7 @@ pub fn gen_ndarray_arange<'ctx>(
let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim"); let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim");
// Allocate the resulting ndarray // Allocate the resulting ndarray
let ndarray = NDArrayObject::alloca_uninitialized( let ndarray = NDArrayObject::alloca(
generator, generator,
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
@ -385,20 +384,17 @@ pub fn gen_ndarray_shape<'ctx>(
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Define models
let sizet_model = IntModel(SizeT);
// Process ndarray // Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut objects = Vec::with_capacity(ndarray.ndims as usize); let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims { for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
let dim = ndarray let dim = ndarray
.instance .instance
.get(generator, ctx, |f| f.shape, "") .get(generator, ctx, |f| f.shape, "")
.ix(generator, ctx, i.value, "dim"); .offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects objects
@ -427,20 +423,17 @@ pub fn gen_ndarray_strides<'ctx>(
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Define models
let sizet_model = IntModel(SizeT);
// Process ndarray // Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut objects = Vec::with_capacity(ndarray.ndims as usize); let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims { for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
let dim = ndarray let dim = ndarray
.instance .instance
.get(generator, ctx, |f| f.strides, "") .get(generator, ctx, |f| f.strides, "")
.ix(generator, ctx, i.value, "dim"); .offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects objects
@ -524,7 +517,7 @@ pub fn gen_ndarray_array<'ctx>(
// We simply make the output ndarray's ndims correct with `atleast_nd`. // We simply make the output ndarray's ndims correct with `atleast_nd`.
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let output_ndims = extract_ndims(&mut ctx.unifier, ndims); let output_ndims = extract_ndims(&ctx.unifier, ndims);
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8 let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
let copy = copy.truncate(generator, ctx, Bool, "copy_bool"); let copy = copy.truncate(generator, ctx, Bool, "copy_bool");

View File

@ -1,5 +1,3 @@
use inkwell::values::BasicValue;
use crate::{ use crate::{
codegen::{model::*, structure::List, CodeGenContext, CodeGenerator}, codegen::{model::*, structure::List, CodeGenContext, CodeGenerator},
typecheck::typedef::{iter_type_vars, Type, TypeEnum}, typecheck::typedef::{iter_type_vars, Type, TypeEnum},

View File

@ -41,13 +41,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape"); let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape");
call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape); call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape);
let ndarray = NDArrayObject::alloca_uninitialized( let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int, "ndarray_from_list");
generator,
ctx,
dtype,
ndims_int,
"ndarray_from_list",
);
ndarray.copy_shape_from_array(generator, ctx, shape); ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx); ndarray.create_data(generator, ctx);
@ -73,13 +67,8 @@ impl<'ctx> NDArrayObject<'ctx> {
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
if ndims == 1 { if ndims == 1 {
// `list` is not nested, does not need to copy. // `list` is not nested, does not need to copy.
let ndarray = NDArrayObject::alloca_uninitialized( let ndarray =
generator, NDArrayObject::alloca(generator, ctx, dtype, 1, "ndarray_from_list_no_copy");
ctx,
dtype,
1,
"ndarray_from_list_no_copy",
);
// Set data // Set data
let data = list.get_opaque_items_ptr(generator, ctx); let data = list.get_opaque_items_ptr(generator, ctx);

View File

@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> {
target_ndims: u64, target_ndims: u64,
target_shape: Ptr<'ctx, IntModel<SizeT>>, target_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self { ) -> Self {
let broadcast_ndarray = NDArrayObject::alloca_uninitialized( let broadcast_ndarray = NDArrayObject::alloca(
generator, generator,
ctx, ctx,
self.dtype, self.dtype,

View File

@ -80,10 +80,10 @@ where
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats // Special handling for floats
let n = scalar.instance.into_float_value(); let n = scalar.value.into_float_value();
handle_float(generator, ctx, n) handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.instance.into_int_value(); let n = scalar.value.into_int_value();
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() { if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap() ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
@ -95,7 +95,7 @@ where
}; };
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { instance: result.into(), dtype: ret_int_dtype } ScalarObject { value: result.into(), dtype: ret_int_dtype }
} }
impl<'ctx> ScalarObject<'ctx> { impl<'ctx> ScalarObject<'ctx> {
@ -104,7 +104,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong. /// Panic if the type is wrong.
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> { pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.instance.into_float_value() // self.value must be a FloatValue self.value.into_float_value() // self.value must be a FloatValue
} else { } else {
panic!("not a float type") panic!("not a float type")
} }
@ -115,7 +115,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong. /// Panic if the type is wrong.
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) { if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
let value = self.instance.into_int_value(); let value = self.value.into_int_value();
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
value value
} else { } else {
@ -142,12 +142,12 @@ impl<'ctx> ScalarObject<'ctx> {
let common_ty = lhs.dtype; let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) { let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.instance.into_float_value(); let lhs = lhs.value.into_float_value();
let rhs = rhs.instance.into_float_value(); let rhs = rhs.value.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap() ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.instance.into_int_value(); let lhs = lhs.value.into_int_value();
let rhs = rhs.instance.into_int_value(); let rhs = rhs.value.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else { } else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]); unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
@ -266,14 +266,14 @@ impl<'ctx> ScalarObject<'ctx> {
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type? // TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.instance.into_int_value() self.value.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.instance.into_int_value(); let n = self.value.into_int_value();
ctx.builder ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap() .unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { } else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
ctx.builder ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap() .unwrap()
@ -281,7 +281,7 @@ impl<'ctx> ScalarObject<'ctx> {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
}; };
ScalarObject { dtype: ctx.primitives.bool, instance: result.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `float()`. /// Invoke NAC3's builtin `float()`.
@ -290,21 +290,21 @@ impl<'ctx> ScalarObject<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.instance.into_float_value() self.value.into_float_value()
} else if ctx } else if ctx
.unifier .unifier
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat()) .unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
{ {
let n = self.instance.into_int_value(); let n = self.value.into_int_value();
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
let n = self.instance.into_int_value(); let n = self.value.into_int_value();
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap() ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
} else { } else {
unsupported_type(ctx, [self.dtype]); unsupported_type(ctx, [self.dtype]);
}; };
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float } ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} }
/// Invoke NAC3's builtin `round()`. /// Invoke NAC3's builtin `round()`.
@ -318,13 +318,13 @@ impl<'ctx> ScalarObject<'ctx> {
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None); let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap() ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
} else { } else {
unsupported_type(ctx, [self.dtype, ret_int_dtype]) unsupported_type(ctx, [self.dtype, ret_int_dtype])
}; };
ScalarObject { dtype: ret_int_dtype, instance: result.as_basic_value_enum() } ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `np_round()`. /// Invoke NAC3's builtin `np_round()`.
@ -333,12 +333,12 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use] #[must_use]
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None) llvm_intrinsics::call_float_rint(ctx, n, None)
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
}; };
ScalarObject { dtype: ctx.primitives.float, instance: result.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `min()` or `max()`. /// Invoke NAC3's builtin `min()` or `max()`.
@ -360,8 +360,8 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Max => llvm_intrinsics::call_float_maxnum, MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
}; };
let result = let result =
function(ctx, a.instance.into_float_value(), b.instance.into_float_value(), None); function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float } ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any( } else if ctx.unifier.unioned_any(
common_dtype, common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(), [unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
@ -371,9 +371,8 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Min => llvm_intrinsics::call_int_umin, MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax, MinOrMax::Max => llvm_intrinsics::call_int_umax,
}; };
let result = let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
function(ctx, a.instance.into_int_value(), b.instance.into_int_value(), None); ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
ScalarObject { instance: result.as_basic_value_enum(), dtype: common_dtype }
} else { } else {
unsupported_type(ctx, [common_dtype]) unsupported_type(ctx, [common_dtype])
} }
@ -399,11 +398,11 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
}; };
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
let n = function(ctx, n, None); let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap(); let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: ret_int_dtype, instance: n.as_basic_value_enum() } ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -419,9 +418,9 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
}; };
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
let n = function(ctx, n, None); let n = function(ctx, n, None);
ScalarObject { dtype: ctx.primitives.float, instance: n.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -431,16 +430,16 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use] #[must_use]
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.instance.into_float_value(); let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
ScalarObject { instance: n.into(), dtype: ctx.primitives.float } ScalarObject { value: n.into(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.instance.into_int_value(); let n = self.value.into_int_value();
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
ScalarObject { instance: n.into(), dtype: self.dtype } ScalarObject { value: n.into(), dtype: self.dtype }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -482,7 +481,7 @@ impl<'ctx> NDArrayObject<'ctx> {
pextremum_index.store(ctx, zero); pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero); let first_scalar = self.get_nth(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.instance).unwrap(); ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
// Find extremum // Find extremum
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1 let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
@ -495,7 +494,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let scalar = self.get_nth(generator, ctx, i); let scalar = self.get_nth(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, instance: old_extremum }; let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
@ -523,7 +522,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index"); let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap(); let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, instance: extremum }; let extremum = ScalarObject { dtype: self.dtype, value: extremum };
(extremum, extremum_index) (extremum, extremum_index)
} }

View File

@ -1,6 +1,6 @@
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator}; use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
use super::{scalar::ScalarOrNDArray, NDArrayObject}; use super::NDArrayObject;
pub type NDIndexType = Byte; pub type NDIndexType = Byte;
@ -215,8 +215,7 @@ impl<'ctx> NDArrayObject<'ctx> {
name: &str, name: &str,
) -> Self { ) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indexes); let dst_ndims = self.deduce_ndims_after_indexing_with(indexes);
let dst_ndarray = let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims, name);
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, dst_ndims, name);
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes); let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes);
call_nac3_ndarray_index( call_nac3_ndarray_index(

View File

@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let out_ndarray = match out { let out_ndarray = match out {
NDArrayOut::NewNDArray { dtype } => { NDArrayOut::NewNDArray { dtype } => {
// Create a new ndarray based on the broadcast shape. // Create a new ndarray based on the broadcast shape.
let result_ndarray = NDArrayObject::alloca_uninitialized( let result_ndarray = NDArrayObject::alloca(
generator, generator,
ctx, ctx,
dtype, dtype,
@ -137,7 +137,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
if let Some(scalars) = all_scalars { if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar = let scalar =
ScalarObject { instance: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype }; ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
Ok(ScalarOrNDArray::Scalar(scalar)) Ok(ScalarOrNDArray::Scalar(scalar))
} else { } else {
// Promote all input to ndarrays and map through them. // Promote all input to ndarrays and map through them.

View File

@ -22,7 +22,10 @@ use crate::{
structure::{NDArray, SimpleNDArray}, structure::{NDArray, SimpleNDArray},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, toplevel::{
helper::{create_ndims, extract_ndims},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::Type, typecheck::typedef::Type,
}; };
use indexing::RustNDIndex; use indexing::RustNDIndex;
@ -71,6 +74,12 @@ impl<'ctx> NDArrayObject<'ctx> {
NDArrayObject { dtype, ndims, instance: value } NDArrayObject { dtype, ndims, instance: value }
} }
/// Forget that this is an ndarray and convert to an [`AnyObject`].
pub fn to_any_object(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let ty = self.get_ndarray_type(ctx);
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
}
/// Create a [`SimpleNDArray`] from the contents of this ndarray. /// Create a [`SimpleNDArray`] from the contents of this ndarray.
/// ///
/// This function may or may not be expensive depending on if this ndarray has contiguous data. /// This function may or may not be expensive depending on if this ndarray has contiguous data.
@ -88,6 +97,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
item_model: Item, item_model: Item,
name: &str,
) -> Ptr<'ctx, StructModel<SimpleNDArray<Item>>> { ) -> Ptr<'ctx, StructModel<SimpleNDArray<Item>>> {
// Sanity check on `self.dtype` and `item_model`. // Sanity check on `self.dtype` and `item_model`.
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
@ -101,7 +111,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
// Allocate and setup the resulting [`SimpleNDArray`]. // Allocate and setup the resulting [`SimpleNDArray`].
let result = simple_ndarray_model.alloca(generator, ctx, "simple_ndarray"); let result = simple_ndarray_model.alloca(generator, ctx, name);
// Set ndims and shape. // Set ndims and shape.
let ndims = self.get_ndims(generator, ctx.ctx); let ndims = self.get_ndims(generator, ctx.ctx);
@ -155,13 +165,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// TODO: Check if `ndims` is consistent with that in `simple_array`? // TODO: Check if `ndims` is consistent with that in `simple_array`?
// Allocate the resulting ndarray. // Allocate the resulting ndarray.
let ndarray = NDArrayObject::alloca_uninitialized( let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "from_simple_ndarray");
generator,
ctx,
dtype,
ndims,
"from_simple_ndarray",
);
// Set data, shape by simply copying addresses. // Set data, shape by simply copying addresses.
let data = simple_ndarray let data = simple_ndarray
@ -178,6 +182,12 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray ndarray
} }
/// Get the typechecker ndarray type of this [`NDArrayObject`].
pub fn get_ndarray_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
}
/// Get the `np.size()` of this ndarray. /// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>( pub fn size<G: CodeGenerator + ?Sized>(
&self, &self,
@ -243,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> ScalarObject<'ctx> { ) -> ScalarObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value"); let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap(); let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, instance: value } ScalarObject { dtype: self.dtype, value }
} }
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
@ -283,7 +293,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// - `ndims`: set to the value of `ndims`. /// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values. /// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values. /// - `strides`: allocated with an array of length `ndims` with uninitialized values.
pub fn alloca_uninitialized<G: CodeGenerator + ?Sized>( pub fn alloca<G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type, dtype: Type,
@ -318,7 +328,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Convenience function. /// Convenience function.
/// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray. /// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray.
pub fn alloca_uninitialized_of_type<G: CodeGenerator + ?Sized>( pub fn alloca_ndarray_type<G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ty: Type, ndarray_ty: Type,
@ -326,11 +336,34 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> Self { ) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name) Self::alloca(generator, ctx, dtype, ndims, name)
} }
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
/// over. ///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[u64],
name: &str,
) -> Self {
let sizet_model = IntModel(SizeT);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
let dim = sizet_model.constant(generator, ctx.ctx, *dim);
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, dim);
}
ndarray
}
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
/// ///
/// The new ndarray will own its data and will be C-contiguous. /// The new ndarray will own its data and will be C-contiguous.
#[must_use] #[must_use]
@ -340,8 +373,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: &str, name: &str,
) -> Self { ) -> Self {
let clone = let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, name);
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape"); let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape); clone.copy_shape_from_array(generator, ctx, shape);
@ -519,7 +551,7 @@ impl<'ctx> NDArrayObject<'ctx> {
{ {
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| { self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap(); let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, instance: value }; let scalar = ScalarObject { dtype: self.dtype, value };
body(generator, ctx, hooks, i, scalar) body(generator, ctx, hooks, i, scalar)
}) })
} }
@ -588,13 +620,8 @@ impl<'ctx> NDArrayObject<'ctx> {
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray = NDArrayObject::alloca_uninitialized( let dst_ndarray =
generator, NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
ctx,
self.dtype,
new_ndims,
"reshaped_ndarray",
);
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
let size = self.size(generator, ctx); let size = self.size(generator, ctx);
@ -661,13 +688,8 @@ impl<'ctx> NDArrayObject<'ctx> {
// Define models // Define models
let sizet_model = IntModel(SizeT); let sizet_model = IntModel(SizeT);
let transposed_ndarray = NDArrayObject::alloca_uninitialized( let transposed_ndarray =
generator, NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, "transposed_ndarray");
ctx,
self.dtype,
self.ndims,
"transposed_ndarray",
);
let num_axes = self.get_ndims(generator, ctx.ctx); let num_axes = self.get_ndims(generator, ctx.ctx);
@ -686,7 +708,7 @@ impl<'ctx> NDArrayObject<'ctx> {
transposed_ndarray transposed_ndarray
} }
/// Check if this NDArray can be used as an `out` ndarray for an operation. /// Check if this `NDArray` can be used as an `out` ndarray for an operation.
/// ///
/// Raise an exception if the shapes do not match. /// Raise an exception if the shapes do not match.
pub fn check_can_be_written_by_out<G: CodeGenerator + ?Sized>( pub fn check_can_be_written_by_out<G: CodeGenerator + ?Sized>(

View File

@ -1 +1,53 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::codegen::{model::*, structure::SimpleNDArray, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub fn perform_nalgebra_call<'ctx, 'a, const NUM_INPUTS: usize, const NUM_OUTPUTS: usize, G, F>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: [NDArrayObject<'ctx>; NUM_INPUTS],
output_ndims: [u64; NUM_OUTPUTS],
invoke_function: F,
) -> [NDArrayObject<'ctx>; NUM_OUTPUTS]
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
[BasicValueEnum<'ctx>; NUM_INPUTS],
[BasicValueEnum<'ctx>; NUM_OUTPUTS],
),
{
// TODO: Allow stacked inputs. See NumPy docs.
let f64_model = FloatModel(Float64);
let simple_ndarray_model = StructModel(SimpleNDArray { item: f64_model });
// Prepare inputs & outputs and invoke
let inputs = inputs.map(|input| {
// Sanity check. Typechecker ensures this.
assert!(ctx.unifier.unioned(input.dtype, ctx.primitives.float));
input
.make_simple_ndarray(generator, ctx, FloatModel(Float64), "nalgebra_input")
.value
.as_basic_value_enum()
});
let outputs = [simple_ndarray_model.alloca(generator, ctx, "nalgebra_output"); NUM_OUTPUTS];
invoke_function(ctx, inputs, outputs.map(|output| output.value.as_basic_value_enum()));
// Turn the outputs into strided NDArrays
let mut output_i = 0;
outputs.map(|output| {
let out = NDArrayObject::from_simple_ndarray(
generator,
ctx,
output,
ctx.primitives.float,
output_ndims[output_i],
);
output_i += 1;
out
})
}

View File

@ -55,7 +55,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape); let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape); let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
let dst = NDArrayObject::alloca_uninitialized( let dst = NDArrayObject::alloca(
generator, generator,
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,

View File

@ -15,7 +15,7 @@ use super::NDArrayObject;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> { pub struct ScalarObject<'ctx> {
pub dtype: Type, pub dtype: Type,
pub instance: BasicValueEnum<'ctx>, pub value: BasicValueEnum<'ctx>,
} }
impl<'ctx> ScalarObject<'ctx> { impl<'ctx> ScalarObject<'ctx> {
@ -31,12 +31,11 @@ impl<'ctx> ScalarObject<'ctx> {
let pbyte_model = PtrModel(IntModel(Byte)); let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer. // We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(self.instance.get_type(), "as_ndarray_scalar").unwrap(); let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.instance).unwrap(); ctx.builder.build_store(data, self.value).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data"); let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray = let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray");
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray");
ndarray.instance.set(ctx, |f| f.data, data); ndarray.instance.set(ctx, |f| f.data, data);
ndarray ndarray
} }
@ -54,7 +53,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
#[must_use] #[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self { match self {
ScalarOrNDArray::Scalar(scalar) => scalar.instance, ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
} }
} }
@ -135,7 +134,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
ScalarOrNDArray::NDArray(ndarray) ScalarOrNDArray::NDArray(ndarray)
} }
_ => { _ => {
let scalar = ScalarObject { dtype: object.ty, instance: object.value }; let scalar = ScalarObject { dtype: object.ty, value: object.value };
ScalarOrNDArray::Scalar(scalar) ScalarOrNDArray::Scalar(scalar)
} }
} }

View File

@ -48,8 +48,9 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
// Load the i-th int32 in the input sequence // Load the i-th int32 in the input sequence
let int = input_sequence let int = input_sequence
.instance .instance
.get(generator, ctx, |f| f.items, "int") .get(generator, ctx, |f| f.items, "")
.ix(generator, ctx, i.value, "int") .offset(generator, ctx, i.value, "")
.load(generator, ctx, "")
.value .value
.into_int_value(); .into_int_value();
@ -65,7 +66,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
(len, result) (len, result)
} }
TypeEnum::TTuple { ty: tuple_types, .. } => { TypeEnum::TTuple { .. } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let input_sequence = TupleObject::from_object(ctx, input_sequence); let input_sequence = TupleObject::from_object(ctx, input_sequence);

View File

@ -34,13 +34,13 @@ impl<'ctx> TupleObject<'ctx> {
}; };
let value = object.value.into_struct_value(); let value = object.value.into_struct_value();
if value.get_type().count_fields() as usize != tys.len() { let value_num_fields = value.get_type().count_fields() as usize;
panic!( assert!(
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)", value_num_fields != tys.len(),
tys.len(), "Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
value.get_type().count_fields() tys.len(),
); value_num_fields
} );
TupleObject { tys: tys.clone(), value } TupleObject { tys: tys.clone(), value }
} }
@ -74,18 +74,23 @@ impl<'ctx> TupleObject<'ctx> {
/// Get the `len()` of this tuple. /// Get the `len()` of this tuple.
/// ///
/// We statically know the lengths of tuples in NAC3. /// We statically know the lengths of tuples in NAC3.
#[must_use]
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.tys.len() self.tys.len()
} }
/// Check if this tuple is an empty/unit tuple.
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Get the `i`-th (0-based) object in this tuple. /// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> { pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
if i >= self.len() { assert!(i >= self.len(), "Tuple object with length {} have index {i}", self.len());
panic!("Tuple object with length {} have index {i}", self.len());
}
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap(); let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i]; let ty = self.tys[i];
AnyObject { value, ty } AnyObject { ty, value }
} }
} }

View File

@ -1827,7 +1827,7 @@ pub fn gen_stmt<G: CodeGenerator>(
let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?; let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?;
cslice_model.check_value(generator, ctx.ctx, msg).unwrap() cslice_model.check_value(generator, ctx.ctx, msg).unwrap()
} }
None => ctx.gen_string(generator, "").into(), None => ctx.gen_string(generator, ""),
}; };
ctx.make_assert_impl( ctx.make_assert_impl(

View File

@ -15,14 +15,19 @@ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ProxyValue, RangeValue}, classes::{ProxyValue, RangeValue},
extern_fns, irrt, llvm_intrinsics, extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
irrt, llvm_intrinsics,
model::{IntModel, SizeT},
numpy::*, numpy::*,
numpy_new::{self, gen_ndarray_transpose}, numpy_new::{self, gen_ndarray_transpose},
object::{ object::{
ndarray::{ ndarray::{
functions::{FloorOrCeil, MinOrMax}, functions::{FloorOrCeil, MinOrMax},
nalgebra::perform_nalgebra_call,
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray}, scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
NDArrayObject,
}, },
tuple::TupleObject,
AnyObject, AnyObject,
}, },
stmt::exn_constructor, stmt::exn_constructor,
@ -1109,7 +1114,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunBool => scalar.cast_to_bool(ctx), PrimDef::FunBool => scalar.cast_to_bool(ctx),
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(result.instance) Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1171,7 +1176,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
ret_int_dtype, ret_int_dtype,
|generator, ctx, _i, scalar| { |generator, ctx, _i, scalar| {
Ok(scalar.round(generator, ctx, ret_int_dtype).instance) Ok(scalar.round(generator, ctx, ret_int_dtype).value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1237,7 +1242,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
int_sized, int_sized,
|generator, ctx, _i, scalar| { |generator, ctx, _i, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).instance) Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1638,7 +1643,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float, ctx.primitives.float,
move |_generator, ctx, _i, scalar| { move |_generator, ctx, _i, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind); let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.instance) Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1667,7 +1672,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float, ctx.primitives.float,
|_generator, ctx, _i, scalar| { |_generator, ctx, _i, scalar| {
let result = scalar.np_round(ctx); let result = scalar.np_round(ctx);
Ok(result.instance) Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1754,10 +1759,10 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(), _ => unreachable!(),
}; };
let m = ScalarObject { dtype: m_ty, instance: m_val }; let m = ScalarObject { dtype: m_ty, value: m_val };
let n = ScalarObject { dtype: n_ty, instance: n_val }; let n = ScalarObject { dtype: n_ty, value: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n); let result = ScalarObject::min_or_max(ctx, kind, m, n);
Ok(Some(result.instance)) Ok(Some(result.value))
}, },
)))), )))),
loc: None, loc: None,
@ -1811,10 +1816,10 @@ impl<'a> BuiltinBuilder<'a> {
.value .value
.as_basic_value_enum(), .as_basic_value_enum(),
PrimDef::FunNpMin => { PrimDef::FunNpMin => {
a.min_or_max(generator, ctx, MinOrMax::Min).instance.as_basic_value_enum() a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum()
} }
PrimDef::FunNpMax => { PrimDef::FunNpMax => {
a.min_or_max(generator, ctx, MinOrMax::Max).instance.as_basic_value_enum() a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum()
} }
_ => unreachable!(), _ => unreachable!(),
}; };
@ -1883,7 +1888,7 @@ impl<'a> BuiltinBuilder<'a> {
let x2 = scalars[1]; let x2 = scalars[1];
let result = ScalarObject::min_or_max(ctx, kind, x1, x2); let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.instance) Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1925,7 +1930,7 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
num_ty.ty, num_ty.ty,
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).instance), |_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}, },
@ -2253,6 +2258,7 @@ impl<'a> BuiltinBuilder<'a> {
), ),
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
// Function type: NDArray[float; 2] -> NDArray[float; 2]
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
@ -2263,14 +2269,22 @@ impl<'a> BuiltinBuilder<'a> {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { value: x1_val, ty: x1_ty };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
let func = match prim { let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky, PrimDef::FunNpLinalgCholesky => extern_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv, PrimDef::FunNpLinalgInv => extern_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv, PrimDef::FunNpLinalgPinv => extern_fns::call_np_linalg_pinv,
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
let [out] =
perform_nalgebra_call(generator, ctx, [x1], [2], |ctx, [x1], [out]| {
func(ctx, x1, out, Some(prim.name()));
});
Ok(Some(out.instance.value.as_basic_value_enum()))
}), }),
) )
} }
@ -2279,6 +2293,7 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur | PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => { | PrimDef::FunSpLinalgHessenberg => {
// Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 2])
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d], ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
is_vararg_ctx: false, is_vararg_ctx: false,
@ -2293,22 +2308,35 @@ impl<'a> BuiltinBuilder<'a> {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { value: x1_val, ty: x1_ty };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
let func = match prim { let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr, PrimDef::FunNpLinalgQr => extern_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu, PrimDef::FunSpLinalgLu => extern_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur, PrimDef::FunSpLinalgSchur => extern_fns::call_sp_linalg_schur,
PrimDef::FunSpLinalgHessenberg => { PrimDef::FunSpLinalgHessenberg => extern_fns::call_sp_linalg_hessenberg,
builtin_fns::call_sp_linalg_hessenberg
}
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
let out = perform_nalgebra_call(
generator,
ctx,
[x1],
[2, 2],
|ctx, [x1], [out1, out2]| func(ctx, x1, out1, out2, Some(prim.name())),
);
// Create the output tuple
let out = out.map(|o| o.to_any_object(ctx));
let out = TupleObject::create(generator, ctx, out, prim.name());
Ok(Some(out.value.as_basic_value_enum()))
}), }),
) )
} }
PrimDef::FunNpLinalgSvd => { PrimDef::FunNpLinalgSvd => {
// Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 1], NDArray[float; 2])
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d], ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
is_vararg_ctx: false, is_vararg_ctx: false,
@ -2323,8 +2351,30 @@ impl<'a> BuiltinBuilder<'a> {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?)) let out = perform_nalgebra_call(
generator,
ctx,
[x1],
[2, 1, 2],
|ctx, [x1], [out1, out2, out3]| {
extern_fns::call_np_linalg_svd(
ctx,
x1,
out1,
out2,
out3,
Some(prim.name()),
);
},
);
// Create the output tuple
let out = out.map(|o| o.to_any_object(ctx));
let out = TupleObject::create(generator, ctx, out, prim.name());
Ok(Some(out.value.as_basic_value_enum()))
}), }),
) )
} }
@ -2337,15 +2387,26 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
// The second argument is converted to an ndarray for implementation convenience.
// TODO: Don't do that.
let x2_ty = fun.0.args[1].ty; let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let x2 = ScalarObject { dtype: x2_ty, value: x2_val };
let x2 = x2.as_ndarray(generator, ctx);
Ok(Some(builtin_fns::call_np_linalg_matrix_power( let [out] = perform_nalgebra_call(
generator, generator,
ctx, ctx,
(x1_ty, x1_val), [x1, x2],
(x2_ty, x2_val), [2],
)?)) |ctx, [x1, x2], [out]| {
call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name()));
},
);
Ok(Some(out.instance.value.as_basic_value_enum()))
}), }),
), ),
PrimDef::FunNpLinalgDet => create_fn_by_codegen( PrimDef::FunNpLinalgDet => create_fn_by_codegen(
@ -2357,7 +2418,22 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?)) let x1 = AnyObject { value: x1_val, ty: x1_ty };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
// The output is returned as a 1D ndarray, even though the result is a single float.
// It is implemented like this at the moment because it is convenient.
// TODO: Don't do that.
let [out] =
perform_nalgebra_call(generator, ctx, [x1], [1], |ctx, [x1], [out]| {
call_np_linalg_det(ctx, x1, out, Some(prim.name()));
});
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let determinant = out.get_nth(generator, ctx, zero);
Ok(Some(determinant.value))
}), }),
), ),
_ => unreachable!(), _ => unreachable!(),

View File

@ -1,12 +1,11 @@
use crate::{ use crate::{
symbol_resolver::SymbolValue,
toplevel::helper::PrimDef, toplevel::helper::PrimDef,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
}, },
}; };
use itertools::{Either, Itertools}; use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments. /// Creates a `ndarray` [`Type`] with the given type arguments.
/// ///