forked from M-Labs/nac3
WIP: core/ndstrides: done
This commit is contained in:
parent
2fbe981701
commit
fd78f7a0e8
@ -1,17 +1,12 @@
|
||||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
|
||||
use inkwell::values::{BasicValueEnum, IntValue};
|
||||
use inkwell::IntPredicate;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{
|
||||
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor,
|
||||
};
|
||||
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, RangeValue, TypedArrayLikeAccessor};
|
||||
use crate::codegen::expr::destructure_range;
|
||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
||||
|
||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||
@ -84,505 +79,3 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
|
||||
fn build_output_struct<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
out_matrices: Vec<BasicValueEnum<'ctx>>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let field_ty =
|
||||
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
|
||||
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||
|
||||
for (i, v) in out_matrices.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
out_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
}
|
||||
}
|
||||
out_ptr
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_cholesky` linalg function
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_cholesky";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_qr` linalg function
|
||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_qr";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_svd` linalg function
|
||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_svd";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_inv` linalg function
|
||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_inv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_pinv` linalg function
|
||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_pinv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_lu` linalg function
|
||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_lu";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
todo!();
|
||||
|
||||
/*
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let n2_array = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
n2_array.data().set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
n2.as_basic_value_enum(),
|
||||
);
|
||||
};
|
||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_det` linalg function
|
||||
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let out = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||
let res =
|
||||
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||
Ok(res)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_schur` linalg function
|
||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_schur";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_hessenberg` linalg function
|
||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_load(out_ptr, "Hessenberg_decomposition_result")
|
||||
.map(Into::into)
|
||||
.unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
@ -32,10 +32,7 @@ use crate::{
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{
|
||||
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
|
||||
StructValue,
|
||||
},
|
||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::{chain, izip, Either, Itertools};
|
||||
@ -314,7 +311,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
self.raise_exn(
|
||||
generator,
|
||||
"0:NotImplementedError",
|
||||
msg.into(),
|
||||
msg,
|
||||
[None, None, None],
|
||||
self.current_loc,
|
||||
);
|
||||
@ -639,7 +636,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
params.map(|p| p.map(|p| param_model.check_value(generator, self.ctx, p).unwrap()));
|
||||
|
||||
let err_msg = self.gen_string(generator, err_msg);
|
||||
self.make_assert_impl(generator, cond, err_name, err_msg.into(), params, loc);
|
||||
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
|
||||
}
|
||||
|
||||
pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
|
||||
@ -1574,9 +1571,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(left.dtype), left.instance),
|
||||
(&Some(left.dtype), left.value),
|
||||
op,
|
||||
(&Some(right.dtype), right.instance),
|
||||
(&Some(right.dtype), right.value),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
@ -2689,7 +2686,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
ctx.raise_exn(
|
||||
generator,
|
||||
"0:UnwrapNoneError",
|
||||
err_msg.into(),
|
||||
err_msg,
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
@ -2,7 +2,6 @@ use crate::symbol_resolver::SymbolResolver;
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
mod test;
|
||||
pub mod util;
|
||||
|
||||
use super::model::*;
|
||||
use super::object::ndarray::broadcast::ShapeEntry;
|
||||
@ -17,6 +16,7 @@ use super::{
|
||||
};
|
||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use function::{get_sizet_dependent_function_name, CallFunction};
|
||||
use inkwell::values::BasicValue;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
@ -29,8 +29,6 @@ use inkwell::{
|
||||
};
|
||||
use itertools::Either;
|
||||
use nac3parser::ast::Expr;
|
||||
use util::function::CallFunction;
|
||||
use util::get_sizet_dependent_function_name;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt(ctx: &Context) -> Module {
|
||||
|
@ -1,107 +0,0 @@
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
|
||||
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
|
||||
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
|
||||
#[must_use]
|
||||
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
name: &str,
|
||||
) -> String {
|
||||
let mut name = name.to_owned();
|
||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => {}
|
||||
64 => name.push_str("64"),
|
||||
bit_width => {
|
||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
|
||||
pub mod function {
|
||||
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||
use inkwell::{
|
||||
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
|
||||
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Arg<'ctx> {
|
||||
ty: BasicMetadataTypeEnum<'ctx>,
|
||||
val: BasicMetadataValueEnum<'ctx>,
|
||||
}
|
||||
|
||||
/// Helper structure to reduce IRRT Inkwell function call boilerplate
|
||||
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
|
||||
generator: &'d mut G,
|
||||
ctx: &'b CodeGenContext<'ctx, 'a>,
|
||||
/// Function name
|
||||
name: &'c str,
|
||||
/// Call arguments
|
||||
args: Vec<Arg<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
|
||||
pub fn begin(
|
||||
generator: &'d mut G,
|
||||
ctx: &'b CodeGenContext<'ctx, 'a>,
|
||||
name: &'c str,
|
||||
) -> Self {
|
||||
CallFunction { generator, ctx, name, args: Vec::new() }
|
||||
}
|
||||
|
||||
/// Push a call argument to the function call.
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
#[must_use]
|
||||
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
|
||||
let arg = Arg {
|
||||
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
|
||||
val: arg.value.as_basic_value_enum().into(),
|
||||
};
|
||||
self.args.push(arg);
|
||||
self
|
||||
}
|
||||
|
||||
/// Call the function and expect the function to return a value of type of `return_model`.
|
||||
#[must_use]
|
||||
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
|
||||
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
|
||||
|
||||
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
|
||||
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
|
||||
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
|
||||
ret
|
||||
}
|
||||
|
||||
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
|
||||
#[must_use]
|
||||
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
|
||||
self.returning(name, M::default())
|
||||
}
|
||||
|
||||
/// Call the function and expect the function to return a void-type.
|
||||
pub fn returning_void(self) {
|
||||
let ret_ty = self.ctx.ctx.void_type();
|
||||
|
||||
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
|
||||
}
|
||||
|
||||
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
|
||||
where
|
||||
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
|
||||
{
|
||||
// Get the LLVM function, declare the function if it doesn't exist - it will be defined by other
|
||||
// components of NAC3.
|
||||
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
|
||||
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
|
||||
let fn_type = make_fn_type(&tys);
|
||||
self.ctx.module.add_function(self.name, fn_type, None)
|
||||
});
|
||||
|
||||
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
|
||||
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
@ -21,6 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
||||
type Type: BasicType<'ctx>;
|
||||
|
||||
/// Return the [`BasicType`] of this model.
|
||||
#[must_use]
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
|
||||
|
||||
/// Check if a [`BasicType`] is the same type of this model.
|
||||
|
88
nac3core/src/codegen/model/float.rs
Normal file
88
nac3core/src/codegen/model/float.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use std::fmt;
|
||||
|
||||
use inkwell::{context::Context, types::FloatType, values::FloatValue};
|
||||
|
||||
use crate::codegen::CodeGenerator;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
|
||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> FloatType<'ctx>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Float32;
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Float64;
|
||||
|
||||
impl<'ctx> FloatKind<'ctx> for Float32 {
|
||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> FloatType<'ctx> {
|
||||
ctx.f32_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> FloatKind<'ctx> for Float64 {
|
||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> FloatType<'ctx> {
|
||||
ctx.f64_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
|
||||
|
||||
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
|
||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_generator: &G,
|
||||
_ctx: &'ctx Context,
|
||||
) -> FloatType<'ctx> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct FloatModel<N>(pub N);
|
||||
pub type Float<'ctx, N> = Instance<'ctx, FloatModel<N>>;
|
||||
|
||||
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for FloatModel<N> {
|
||||
type Value = FloatValue<'ctx>;
|
||||
type Type = FloatType<'ctx>;
|
||||
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||
self.0.get_float_type(generator, ctx)
|
||||
}
|
||||
|
||||
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &'ctx Context,
|
||||
ty: T,
|
||||
) -> Result<(), ModelError> {
|
||||
let ty = ty.as_basic_type_enum();
|
||||
let Ok(ty) = FloatType::try_from(ty) else {
|
||||
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
|
||||
};
|
||||
|
||||
let exp_ty = self.0.get_float_type(generator, ctx);
|
||||
|
||||
// TODO: Inkwell does not have get_bit_width for FloatType?
|
||||
// TODO: Quick hack for now, but does this actually work?
|
||||
if ty != exp_ty {
|
||||
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
125
nac3core/src/codegen/model/function.rs
Normal file
125
nac3core/src/codegen/model/function.rs
Normal file
@ -0,0 +1,125 @@
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
|
||||
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
|
||||
use super::*;
|
||||
|
||||
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
|
||||
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
|
||||
#[must_use]
|
||||
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
name: &str,
|
||||
) -> String {
|
||||
let mut name = name.to_owned();
|
||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => {}
|
||||
64 => name.push_str("64"),
|
||||
bit_width => {
|
||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Arg<'ctx> {
|
||||
ty: BasicMetadataTypeEnum<'ctx>,
|
||||
val: BasicMetadataValueEnum<'ctx>,
|
||||
}
|
||||
|
||||
/// A structure to construct & call an LLVM function.
|
||||
///
|
||||
/// This is a helper to reduce IRRT Inkwell function call boilerplate
|
||||
// TODO: Remove the lifetimes somehow? There is 4 of them.
|
||||
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
|
||||
generator: &'d mut G,
|
||||
ctx: &'b CodeGenContext<'ctx, 'a>,
|
||||
/// Function name
|
||||
name: &'c str,
|
||||
/// Call arguments
|
||||
args: Vec<Arg<'ctx>>,
|
||||
/// LLVM function Attributes
|
||||
attrs: Vec<&'static str>,
|
||||
}
|
||||
|
||||
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
|
||||
pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
|
||||
CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
|
||||
}
|
||||
|
||||
/// Push a list of LLVM function attributes to the function declaration.
|
||||
#[must_use]
|
||||
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
|
||||
self.attrs = attrs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Push a call argument to the function call.
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
#[must_use]
|
||||
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
|
||||
let arg = Arg {
|
||||
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
|
||||
val: arg.value.as_basic_value_enum().into(),
|
||||
};
|
||||
self.args.push(arg);
|
||||
self
|
||||
}
|
||||
|
||||
/// Call the function and expect the function to return a value of type of `return_model`.
|
||||
#[must_use]
|
||||
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
|
||||
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
|
||||
|
||||
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
|
||||
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
|
||||
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
|
||||
ret
|
||||
}
|
||||
|
||||
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
|
||||
#[must_use]
|
||||
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
|
||||
self.returning(name, M::default())
|
||||
}
|
||||
|
||||
/// Call the function and expect the function to return a void-type.
|
||||
pub fn returning_void(self) {
|
||||
let ret_ty = self.ctx.ctx.void_type();
|
||||
|
||||
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
|
||||
}
|
||||
|
||||
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
|
||||
where
|
||||
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
|
||||
{
|
||||
// Get the LLVM function.
|
||||
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
|
||||
// Declare the function if it doesn't exist.
|
||||
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
|
||||
|
||||
let func_type = make_fn_type(&tys);
|
||||
let func = self.ctx.module.add_function(self.name, func_type, None);
|
||||
|
||||
for attr in &self.attrs {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
|
||||
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
|
||||
}
|
||||
}
|
@ -96,7 +96,6 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
|
||||
type Value = IntValue<'ctx>;
|
||||
type Type = IntType<'ctx>;
|
||||
|
||||
#[must_use]
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||
self.0.get_int_type(generator, ctx)
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
mod any;
|
||||
mod core;
|
||||
mod float;
|
||||
pub mod function;
|
||||
mod int;
|
||||
mod ptr;
|
||||
mod structure;
|
||||
@ -7,6 +9,7 @@ pub mod util;
|
||||
|
||||
pub use any::*;
|
||||
pub use core::*;
|
||||
pub use float::*;
|
||||
pub use int::*;
|
||||
pub use ptr::*;
|
||||
pub use structure::*;
|
||||
|
@ -89,15 +89,17 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, Element> {
|
||||
self.model.check_value(generator, ctx.ctx, new_ptr).unwrap()
|
||||
}
|
||||
|
||||
// Load the `i`-th element (0-based) on the array with [`inkwell::builder::Builder::build_in_bounds_gep`].
|
||||
pub fn ix<G: CodeGenerator + ?Sized>(
|
||||
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
|
||||
#[must_use]
|
||||
pub fn offset_const<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
i: IntValue<'ctx>,
|
||||
offset: u64,
|
||||
name: &str,
|
||||
) -> Instance<'ctx, Element> {
|
||||
self.offset(generator, ctx, i, name).load(generator, ctx, name)
|
||||
) -> Ptr<'ctx, Element> {
|
||||
let offset = ctx.ctx.i32_type().const_int(offset, false);
|
||||
self.offset(generator, ctx, offset, name)
|
||||
}
|
||||
|
||||
/// Load the value with [`inkwell::builder::Builder::build_load`].
|
||||
|
@ -94,8 +94,7 @@ where
|
||||
{
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
let ndarray =
|
||||
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
|
||||
let ndarray = NDArrayObject::alloca_ndarray_type(generator, ctx, ndarray_ty, "ndarray");
|
||||
|
||||
// Validate `shape`
|
||||
let ndims = ndarray.get_ndims(generator, ctx.ctx);
|
||||
@ -321,7 +320,7 @@ pub fn gen_ndarray_arange<'ctx>(
|
||||
let input = sizet_model.s_extend_or_bit_cast(generator, ctx, input, "input_dim");
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
let ndarray = NDArrayObject::alloca_uninitialized(
|
||||
let ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
@ -385,20 +384,17 @@ pub fn gen_ndarray_shape<'ctx>(
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Process ndarray
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||
|
||||
for i in 0..ndarray.ndims {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||
let dim = ndarray
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.shape, "")
|
||||
.ix(generator, ctx, i.value, "dim");
|
||||
.offset_const(generator, ctx, i, "")
|
||||
.load(generator, ctx, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
objects
|
||||
@ -427,20 +423,17 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Process ndarray
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||
|
||||
for i in 0..ndarray.ndims {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||
let dim = ndarray
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.strides, "")
|
||||
.ix(generator, ctx, i.value, "dim");
|
||||
.offset_const(generator, ctx, i, "")
|
||||
.load(generator, ctx, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
objects
|
||||
@ -524,7 +517,7 @@ pub fn gen_ndarray_array<'ctx>(
|
||||
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
||||
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let output_ndims = extract_ndims(&mut ctx.unifier, ndims);
|
||||
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||
|
@ -1,5 +1,3 @@
|
||||
use inkwell::values::BasicValue;
|
||||
|
||||
use crate::{
|
||||
codegen::{model::*, structure::List, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
||||
|
@ -41,13 +41,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape");
|
||||
call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape);
|
||||
|
||||
let ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
ndims_int,
|
||||
"ndarray_from_list",
|
||||
);
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int, "ndarray_from_list");
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||
ndarray.create_data(generator, ctx);
|
||||
|
||||
@ -73,13 +67,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
|
||||
if ndims == 1 {
|
||||
// `list` is not nested, does not need to copy.
|
||||
let ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
1,
|
||||
"ndarray_from_list_no_copy",
|
||||
);
|
||||
let ndarray =
|
||||
NDArrayObject::alloca(generator, ctx, dtype, 1, "ndarray_from_list_no_copy");
|
||||
|
||||
// Set data
|
||||
let data = list.get_opaque_items_ptr(generator, ctx);
|
||||
|
@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
target_ndims: u64,
|
||||
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let broadcast_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
let broadcast_ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
self.dtype,
|
||||
|
@ -80,10 +80,10 @@ where
|
||||
|
||||
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
|
||||
// Special handling for floats
|
||||
let n = scalar.instance.into_float_value();
|
||||
let n = scalar.value.into_float_value();
|
||||
handle_float(generator, ctx, n)
|
||||
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
|
||||
let n = scalar.instance.into_int_value();
|
||||
let n = scalar.value.into_int_value();
|
||||
|
||||
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
|
||||
@ -95,7 +95,7 @@ where
|
||||
};
|
||||
|
||||
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
ScalarObject { instance: result.into(), dtype: ret_int_dtype }
|
||||
ScalarObject { value: result.into(), dtype: ret_int_dtype }
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
@ -104,7 +104,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.instance.into_float_value() // self.value must be a FloatValue
|
||||
self.value.into_float_value() // self.value must be a FloatValue
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
@ -115,7 +115,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
|
||||
let value = self.instance.into_int_value();
|
||||
let value = self.value.into_int_value();
|
||||
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
|
||||
value
|
||||
} else {
|
||||
@ -142,12 +142,12 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
|
||||
let common_ty = lhs.dtype;
|
||||
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
|
||||
let lhs = lhs.instance.into_float_value();
|
||||
let rhs = rhs.instance.into_float_value();
|
||||
let lhs = lhs.value.into_float_value();
|
||||
let rhs = rhs.value.into_float_value();
|
||||
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
|
||||
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
||||
let lhs = lhs.instance.into_int_value();
|
||||
let rhs = rhs.instance.into_int_value();
|
||||
let lhs = lhs.value.into_int_value();
|
||||
let rhs = rhs.value.into_int_value();
|
||||
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
||||
@ -266,14 +266,14 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
|
||||
self.instance.into_int_value()
|
||||
self.value.into_int_value()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.instance.into_int_value();
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder
|
||||
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
@ -281,7 +281,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
|
||||
ScalarObject { dtype: ctx.primitives.bool, instance: result.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `float()`.
|
||||
@ -290,21 +290,21 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.instance.into_float_value()
|
||||
self.value.into_float_value()
|
||||
} else if ctx
|
||||
.unifier
|
||||
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
|
||||
{
|
||||
let n = self.instance.into_int_value();
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
|
||||
let n = self.instance.into_int_value();
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype]);
|
||||
};
|
||||
|
||||
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `round()`.
|
||||
@ -318,13 +318,13 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype, ret_int_dtype])
|
||||
};
|
||||
ScalarObject { dtype: ret_int_dtype, instance: result.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_round()`.
|
||||
@ -333,12 +333,12 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
#[must_use]
|
||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
llvm_intrinsics::call_float_rint(ctx, n, None)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
ScalarObject { dtype: ctx.primitives.float, instance: result.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `min()` or `max()`.
|
||||
@ -360,8 +360,8 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
||||
};
|
||||
let result =
|
||||
function(ctx, a.instance.into_float_value(), b.instance.into_float_value(), None);
|
||||
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(
|
||||
common_dtype,
|
||||
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
|
||||
@ -371,9 +371,8 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
||||
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
||||
};
|
||||
let result =
|
||||
function(ctx, a.instance.into_int_value(), b.instance.into_int_value(), None);
|
||||
ScalarObject { instance: result.as_basic_value_enum(), dtype: common_dtype }
|
||||
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [common_dtype])
|
||||
}
|
||||
@ -399,11 +398,11 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
|
||||
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
|
||||
ScalarObject { dtype: ret_int_dtype, instance: n.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
@ -419,9 +418,9 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
ScalarObject { dtype: ctx.primitives.float, instance: n.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
@ -431,16 +430,16 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
#[must_use]
|
||||
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.instance.into_float_value();
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
||||
ScalarObject { instance: n.into(), dtype: ctx.primitives.float }
|
||||
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.instance.into_int_value();
|
||||
let n = self.value.into_int_value();
|
||||
|
||||
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
||||
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
||||
|
||||
ScalarObject { instance: n.into(), dtype: self.dtype }
|
||||
ScalarObject { value: n.into(), dtype: self.dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
@ -482,7 +481,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
pextremum_index.store(ctx, zero);
|
||||
|
||||
let first_scalar = self.get_nth(generator, ctx, zero);
|
||||
ctx.builder.build_store(pextremum, first_scalar.instance).unwrap();
|
||||
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
|
||||
|
||||
// Find extremum
|
||||
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
|
||||
@ -495,7 +494,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let scalar = self.get_nth(generator, ctx, i);
|
||||
|
||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, instance: old_extremum };
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||
|
||||
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
|
||||
|
||||
@ -523,7 +522,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
|
||||
|
||||
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
|
||||
let extremum = ScalarObject { dtype: self.dtype, instance: extremum };
|
||||
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
|
||||
|
||||
(extremum, extremum_index)
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
|
||||
|
||||
use super::{scalar::ScalarOrNDArray, NDArrayObject};
|
||||
use super::NDArrayObject;
|
||||
|
||||
pub type NDIndexType = Byte;
|
||||
|
||||
@ -215,8 +215,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let dst_ndims = self.deduce_ndims_after_indexing_with(indexes);
|
||||
let dst_ndarray =
|
||||
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, dst_ndims, name);
|
||||
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims, name);
|
||||
|
||||
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes);
|
||||
call_nac3_ndarray_index(
|
||||
|
@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let out_ndarray = match out {
|
||||
NDArrayOut::NewNDArray { dtype } => {
|
||||
// Create a new ndarray based on the broadcast shape.
|
||||
let result_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
let result_ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
@ -137,7 +137,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
if let Some(scalars) = all_scalars {
|
||||
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
|
||||
let scalar =
|
||||
ScalarObject { instance: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
|
||||
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
|
@ -22,7 +22,10 @@ use crate::{
|
||||
structure::{NDArray, SimpleNDArray},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||
toplevel::{
|
||||
helper::{create_ndims, extract_ndims},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
use indexing::RustNDIndex;
|
||||
@ -71,6 +74,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
NDArrayObject { dtype, ndims, instance: value }
|
||||
}
|
||||
|
||||
/// Forget that this is an ndarray and convert to an [`AnyObject`].
|
||||
pub fn to_any_object(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||
let ty = self.get_ndarray_type(ctx);
|
||||
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
|
||||
}
|
||||
|
||||
/// Create a [`SimpleNDArray`] from the contents of this ndarray.
|
||||
///
|
||||
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
|
||||
@ -88,6 +97,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
item_model: Item,
|
||||
name: &str,
|
||||
) -> Ptr<'ctx, StructModel<SimpleNDArray<Item>>> {
|
||||
// Sanity check on `self.dtype` and `item_model`.
|
||||
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
||||
@ -101,7 +111,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||
|
||||
// Allocate and setup the resulting [`SimpleNDArray`].
|
||||
let result = simple_ndarray_model.alloca(generator, ctx, "simple_ndarray");
|
||||
let result = simple_ndarray_model.alloca(generator, ctx, name);
|
||||
|
||||
// Set ndims and shape.
|
||||
let ndims = self.get_ndims(generator, ctx.ctx);
|
||||
@ -155,13 +165,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// TODO: Check if `ndims` is consistent with that in `simple_array`?
|
||||
|
||||
// Allocate the resulting ndarray.
|
||||
let ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
ndims,
|
||||
"from_simple_ndarray",
|
||||
);
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "from_simple_ndarray");
|
||||
|
||||
// Set data, shape by simply copying addresses.
|
||||
let data = simple_ndarray
|
||||
@ -178,6 +182,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Get the typechecker ndarray type of this [`NDArrayObject`].
|
||||
pub fn get_ndarray_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
|
||||
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
|
||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
|
||||
}
|
||||
|
||||
/// Get the `np.size()` of this ndarray.
|
||||
pub fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
@ -243,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
) -> ScalarObject<'ctx> {
|
||||
let p = self.get_nth_pointer(generator, ctx, nth, "value");
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
ScalarObject { dtype: self.dtype, instance: value }
|
||||
ScalarObject { dtype: self.dtype, value }
|
||||
}
|
||||
|
||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
||||
@ -283,7 +293,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// - `ndims`: set to the value of `ndims`.
|
||||
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
||||
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
||||
pub fn alloca_uninitialized<G: CodeGenerator + ?Sized>(
|
||||
pub fn alloca<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
@ -318,7 +328,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
|
||||
/// Convenience function.
|
||||
/// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray.
|
||||
pub fn alloca_uninitialized_of_type<G: CodeGenerator + ?Sized>(
|
||||
pub fn alloca_ndarray_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ty: Type,
|
||||
@ -326,11 +336,34 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
) -> Self {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
|
||||
Self::alloca(generator, ctx, dtype, ndims, name)
|
||||
}
|
||||
|
||||
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents
|
||||
/// over.
|
||||
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
||||
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &[u64],
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
|
||||
|
||||
// Write shape
|
||||
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
let dim = sizet_model.constant(generator, ctx.ctx, *dim);
|
||||
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, dim);
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
|
||||
///
|
||||
/// The new ndarray will own its data and will be C-contiguous.
|
||||
#[must_use]
|
||||
@ -340,8 +373,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let clone =
|
||||
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
|
||||
let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, name);
|
||||
|
||||
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
|
||||
clone.copy_shape_from_array(generator, ctx, shape);
|
||||
@ -519,7 +551,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
{
|
||||
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
let scalar = ScalarObject { dtype: self.dtype, instance: value };
|
||||
let scalar = ScalarObject { dtype: self.dtype, value };
|
||||
body(generator, ctx, hooks, i, scalar)
|
||||
})
|
||||
}
|
||||
@ -588,13 +620,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||
|
||||
let dst_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
self.dtype,
|
||||
new_ndims,
|
||||
"reshaped_ndarray",
|
||||
);
|
||||
let dst_ndarray =
|
||||
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
|
||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
||||
|
||||
let size = self.size(generator, ctx);
|
||||
@ -661,13 +688,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let transposed_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
self.dtype,
|
||||
self.ndims,
|
||||
"transposed_ndarray",
|
||||
);
|
||||
let transposed_ndarray =
|
||||
NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, "transposed_ndarray");
|
||||
|
||||
let num_axes = self.get_ndims(generator, ctx.ctx);
|
||||
|
||||
@ -686,7 +708,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
transposed_ndarray
|
||||
}
|
||||
|
||||
/// Check if this NDArray can be used as an `out` ndarray for an operation.
|
||||
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
|
||||
///
|
||||
/// Raise an exception if the shapes do not match.
|
||||
pub fn check_can_be_written_by_out<G: CodeGenerator + ?Sized>(
|
||||
|
@ -1 +1,53 @@
|
||||
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||
|
||||
use crate::codegen::{model::*, structure::SimpleNDArray, CodeGenContext, CodeGenerator};
|
||||
|
||||
use super::NDArrayObject;
|
||||
|
||||
pub fn perform_nalgebra_call<'ctx, 'a, const NUM_INPUTS: usize, const NUM_OUTPUTS: usize, G, F>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: [NDArrayObject<'ctx>; NUM_INPUTS],
|
||||
output_ndims: [u64; NUM_OUTPUTS],
|
||||
invoke_function: F,
|
||||
) -> [NDArrayObject<'ctx>; NUM_OUTPUTS]
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
[BasicValueEnum<'ctx>; NUM_INPUTS],
|
||||
[BasicValueEnum<'ctx>; NUM_OUTPUTS],
|
||||
),
|
||||
{
|
||||
// TODO: Allow stacked inputs. See NumPy docs.
|
||||
|
||||
let f64_model = FloatModel(Float64);
|
||||
let simple_ndarray_model = StructModel(SimpleNDArray { item: f64_model });
|
||||
|
||||
// Prepare inputs & outputs and invoke
|
||||
let inputs = inputs.map(|input| {
|
||||
// Sanity check. Typechecker ensures this.
|
||||
assert!(ctx.unifier.unioned(input.dtype, ctx.primitives.float));
|
||||
|
||||
input
|
||||
.make_simple_ndarray(generator, ctx, FloatModel(Float64), "nalgebra_input")
|
||||
.value
|
||||
.as_basic_value_enum()
|
||||
});
|
||||
let outputs = [simple_ndarray_model.alloca(generator, ctx, "nalgebra_output"); NUM_OUTPUTS];
|
||||
invoke_function(ctx, inputs, outputs.map(|output| output.value.as_basic_value_enum()));
|
||||
|
||||
// Turn the outputs into strided NDArrays
|
||||
let mut output_i = 0;
|
||||
outputs.map(|output| {
|
||||
let out = NDArrayObject::from_simple_ndarray(
|
||||
generator,
|
||||
ctx,
|
||||
output,
|
||||
ctx.primitives.float,
|
||||
output_ndims[output_i],
|
||||
);
|
||||
output_i += 1;
|
||||
out
|
||||
})
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
|
||||
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
|
||||
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
|
||||
let dst = NDArrayObject::alloca_uninitialized(
|
||||
let dst = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|
@ -15,7 +15,7 @@ use super::NDArrayObject;
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ScalarObject<'ctx> {
|
||||
pub dtype: Type,
|
||||
pub instance: BasicValueEnum<'ctx>,
|
||||
pub value: BasicValueEnum<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
@ -31,12 +31,11 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
let pbyte_model = PtrModel(IntModel(Byte));
|
||||
|
||||
// We have to put the value on the stack to get a data pointer.
|
||||
let data = ctx.builder.build_alloca(self.instance.get_type(), "as_ndarray_scalar").unwrap();
|
||||
ctx.builder.build_store(data, self.instance).unwrap();
|
||||
let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap();
|
||||
ctx.builder.build_store(data, self.value).unwrap();
|
||||
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
|
||||
|
||||
let ndarray =
|
||||
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray");
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray");
|
||||
ndarray.instance.set(ctx, |f| f.data, data);
|
||||
ndarray
|
||||
}
|
||||
@ -54,7 +53,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
#[must_use]
|
||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.instance,
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
|
||||
}
|
||||
}
|
||||
@ -135,7 +134,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ScalarOrNDArray::NDArray(ndarray)
|
||||
}
|
||||
_ => {
|
||||
let scalar = ScalarObject { dtype: object.ty, instance: object.value };
|
||||
let scalar = ScalarObject { dtype: object.ty, value: object.value };
|
||||
ScalarOrNDArray::Scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
@ -48,8 +48,9 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// Load the i-th int32 in the input sequence
|
||||
let int = input_sequence
|
||||
.instance
|
||||
.get(generator, ctx, |f| f.items, "int")
|
||||
.ix(generator, ctx, i.value, "int")
|
||||
.get(generator, ctx, |f| f.items, "")
|
||||
.offset(generator, ctx, i.value, "")
|
||||
.load(generator, ctx, "")
|
||||
.value
|
||||
.into_int_value();
|
||||
|
||||
@ -65,7 +66,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
(len, result)
|
||||
}
|
||||
TypeEnum::TTuple { ty: tuple_types, .. } => {
|
||||
TypeEnum::TTuple { .. } => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
|
||||
let input_sequence = TupleObject::from_object(ctx, input_sequence);
|
||||
|
@ -34,13 +34,13 @@ impl<'ctx> TupleObject<'ctx> {
|
||||
};
|
||||
|
||||
let value = object.value.into_struct_value();
|
||||
if value.get_type().count_fields() as usize != tys.len() {
|
||||
panic!(
|
||||
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
||||
tys.len(),
|
||||
value.get_type().count_fields()
|
||||
);
|
||||
}
|
||||
let value_num_fields = value.get_type().count_fields() as usize;
|
||||
assert!(
|
||||
value_num_fields != tys.len(),
|
||||
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
||||
tys.len(),
|
||||
value_num_fields
|
||||
);
|
||||
|
||||
TupleObject { tys: tys.clone(), value }
|
||||
}
|
||||
@ -74,18 +74,23 @@ impl<'ctx> TupleObject<'ctx> {
|
||||
/// Get the `len()` of this tuple.
|
||||
///
|
||||
/// We statically know the lengths of tuples in NAC3.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
self.tys.len()
|
||||
}
|
||||
|
||||
/// Check if this tuple is an empty/unit tuple.
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Get the `i`-th (0-based) object in this tuple.
|
||||
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
|
||||
if i >= self.len() {
|
||||
panic!("Tuple object with length {} have index {i}", self.len());
|
||||
}
|
||||
assert!(i >= self.len(), "Tuple object with length {} have index {i}", self.len());
|
||||
|
||||
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
|
||||
let ty = self.tys[i];
|
||||
AnyObject { value, ty }
|
||||
AnyObject { ty, value }
|
||||
}
|
||||
}
|
||||
|
@ -1827,7 +1827,7 @@ pub fn gen_stmt<G: CodeGenerator>(
|
||||
let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?;
|
||||
cslice_model.check_value(generator, ctx.ctx, msg).unwrap()
|
||||
}
|
||||
None => ctx.gen_string(generator, "").into(),
|
||||
None => ctx.gen_string(generator, ""),
|
||||
};
|
||||
|
||||
ctx.make_assert_impl(
|
||||
|
@ -15,14 +15,19 @@ use crate::{
|
||||
codegen::{
|
||||
builtin_fns,
|
||||
classes::{ProxyValue, RangeValue},
|
||||
extern_fns, irrt, llvm_intrinsics,
|
||||
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
|
||||
irrt, llvm_intrinsics,
|
||||
model::{IntModel, SizeT},
|
||||
numpy::*,
|
||||
numpy_new::{self, gen_ndarray_transpose},
|
||||
object::{
|
||||
ndarray::{
|
||||
functions::{FloorOrCeil, MinOrMax},
|
||||
nalgebra::perform_nalgebra_call,
|
||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||
NDArrayObject,
|
||||
},
|
||||
tuple::TupleObject,
|
||||
AnyObject,
|
||||
},
|
||||
stmt::exn_constructor,
|
||||
@ -1109,7 +1114,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
PrimDef::FunBool => scalar.cast_to_bool(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(result.instance)
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1171,7 +1176,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx,
|
||||
ret_int_dtype,
|
||||
|generator, ctx, _i, scalar| {
|
||||
Ok(scalar.round(generator, ctx, ret_int_dtype).instance)
|
||||
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1237,7 +1242,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx,
|
||||
int_sized,
|
||||
|generator, ctx, _i, scalar| {
|
||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).instance)
|
||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1638,7 +1643,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx.primitives.float,
|
||||
move |_generator, ctx, _i, scalar| {
|
||||
let result = scalar.np_floor_or_ceil(ctx, kind);
|
||||
Ok(result.instance)
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1667,7 +1672,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx.primitives.float,
|
||||
|_generator, ctx, _i, scalar| {
|
||||
let result = scalar.np_round(ctx);
|
||||
Ok(result.instance)
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1754,10 +1759,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let m = ScalarObject { dtype: m_ty, instance: m_val };
|
||||
let n = ScalarObject { dtype: n_ty, instance: n_val };
|
||||
let m = ScalarObject { dtype: m_ty, value: m_val };
|
||||
let n = ScalarObject { dtype: n_ty, value: n_val };
|
||||
let result = ScalarObject::min_or_max(ctx, kind, m, n);
|
||||
Ok(Some(result.instance))
|
||||
Ok(Some(result.value))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
@ -1811,10 +1816,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
.value
|
||||
.as_basic_value_enum(),
|
||||
PrimDef::FunNpMin => {
|
||||
a.min_or_max(generator, ctx, MinOrMax::Min).instance.as_basic_value_enum()
|
||||
a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum()
|
||||
}
|
||||
PrimDef::FunNpMax => {
|
||||
a.min_or_max(generator, ctx, MinOrMax::Max).instance.as_basic_value_enum()
|
||||
a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
@ -1883,7 +1888,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let x2 = scalars[1];
|
||||
|
||||
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
|
||||
Ok(result.instance)
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1925,7 +1930,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
num_ty.ty,
|
||||
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).instance),
|
||||
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
@ -2253,6 +2258,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
),
|
||||
|
||||
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
|
||||
// Function type: NDArray[float; 2] -> NDArray[float; 2]
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
@ -2263,14 +2269,22 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { value: x1_val, ty: x1_ty };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
|
||||
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
|
||||
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
|
||||
PrimDef::FunNpLinalgCholesky => extern_fns::call_np_linalg_cholesky,
|
||||
PrimDef::FunNpLinalgInv => extern_fns::call_np_linalg_inv,
|
||||
PrimDef::FunNpLinalgPinv => extern_fns::call_np_linalg_pinv,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||
|
||||
let [out] =
|
||||
perform_nalgebra_call(generator, ctx, [x1], [2], |ctx, [x1], [out]| {
|
||||
func(ctx, x1, out, Some(prim.name()));
|
||||
});
|
||||
|
||||
Ok(Some(out.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
@ -2279,6 +2293,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
| PrimDef::FunSpLinalgLu
|
||||
| PrimDef::FunSpLinalgSchur
|
||||
| PrimDef::FunSpLinalgHessenberg => {
|
||||
// Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 2])
|
||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
|
||||
is_vararg_ctx: false,
|
||||
@ -2293,22 +2308,35 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { value: x1_val, ty: x1_ty };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
|
||||
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
|
||||
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
|
||||
PrimDef::FunSpLinalgHessenberg => {
|
||||
builtin_fns::call_sp_linalg_hessenberg
|
||||
}
|
||||
PrimDef::FunNpLinalgQr => extern_fns::call_np_linalg_qr,
|
||||
PrimDef::FunSpLinalgLu => extern_fns::call_sp_linalg_lu,
|
||||
PrimDef::FunSpLinalgSchur => extern_fns::call_sp_linalg_schur,
|
||||
PrimDef::FunSpLinalgHessenberg => extern_fns::call_sp_linalg_hessenberg,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||
|
||||
let out = perform_nalgebra_call(
|
||||
generator,
|
||||
ctx,
|
||||
[x1],
|
||||
[2, 2],
|
||||
|ctx, [x1], [out1, out2]| func(ctx, x1, out1, out2, Some(prim.name())),
|
||||
);
|
||||
|
||||
// Create the output tuple
|
||||
let out = out.map(|o| o.to_any_object(ctx));
|
||||
let out = TupleObject::create(generator, ctx, out, prim.name());
|
||||
Ok(Some(out.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
PrimDef::FunNpLinalgSvd => {
|
||||
// Function type: NDArray[float; 2] -> (NDArray[float; 2], NDArray[float; 1], NDArray[float; 2])
|
||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
|
||||
is_vararg_ctx: false,
|
||||
@ -2323,8 +2351,30 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
|
||||
let out = perform_nalgebra_call(
|
||||
generator,
|
||||
ctx,
|
||||
[x1],
|
||||
[2, 1, 2],
|
||||
|ctx, [x1], [out1, out2, out3]| {
|
||||
extern_fns::call_np_linalg_svd(
|
||||
ctx,
|
||||
x1,
|
||||
out1,
|
||||
out2,
|
||||
out3,
|
||||
Some(prim.name()),
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
// Create the output tuple
|
||||
let out = out.map(|o| o.to_any_object(ctx));
|
||||
let out = TupleObject::create(generator, ctx, out, prim.name());
|
||||
Ok(Some(out.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
@ -2337,15 +2387,26 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
// The second argument is converted to an ndarray for implementation convenience.
|
||||
// TODO: Don't do that.
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
let x2 = ScalarObject { dtype: x2_ty, value: x2_val };
|
||||
let x2 = x2.as_ndarray(generator, ctx);
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
|
||||
let [out] = perform_nalgebra_call(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
[x1, x2],
|
||||
[2],
|
||||
|ctx, [x1, x2], [out]| {
|
||||
call_np_linalg_matrix_power(ctx, x1, x2, out, Some(prim.name()));
|
||||
},
|
||||
);
|
||||
Ok(Some(out.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
|
||||
@ -2357,7 +2418,22 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
|
||||
let x1 = AnyObject { value: x1_val, ty: x1_ty };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
// The output is returned as a 1D ndarray, even though the result is a single float.
|
||||
// It is implemented like this at the moment because it is convenient.
|
||||
// TODO: Don't do that.
|
||||
let [out] =
|
||||
perform_nalgebra_call(generator, ctx, [x1], [1], |ctx, [x1], [out]| {
|
||||
call_np_linalg_det(ctx, x1, out, Some(prim.name()));
|
||||
});
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
let determinant = out.get_nth(generator, ctx, zero);
|
||||
|
||||
Ok(Some(determinant.value))
|
||||
}),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
|
@ -1,12 +1,11 @@
|
||||
use crate::{
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use itertools::{Either, Itertools};
|
||||
use itertools::Itertools;
|
||||
|
||||
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
||||
///
|
||||
|
Loading…
Reference in New Issue
Block a user