forked from M-Labs/nac3
[core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
This commit is contained in:
parent
f7e296da53
commit
c58ce9c3a9
@ -461,8 +461,7 @@ fn format_rpc_arg<'ctx>(
|
|||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||||
let llvm_arg =
|
let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None);
|
||||||
NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None);
|
|
||||||
|
|
||||||
let llvm_usize_sizeof = ctx
|
let llvm_usize_sizeof = ctx
|
||||||
.builder
|
.builder
|
||||||
@ -1369,12 +1368,17 @@ fn polymorphic_print<'ctx>(
|
|||||||
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
fmt.push_str("array([");
|
fmt.push_str("array([");
|
||||||
flush(ctx, generator, &mut fmt, &mut args);
|
flush(ctx, generator, &mut fmt, &mut args);
|
||||||
|
|
||||||
let val =
|
let val = NDArrayValue::from_pointer_value(
|
||||||
NDArrayValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
|
value.into_pointer_value(),
|
||||||
|
llvm_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
|
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
|
||||||
let last =
|
let last =
|
||||||
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||||
|
@ -21,7 +21,10 @@ use super::{
|
|||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys},
|
toplevel::{
|
||||||
|
helper::{arraylike_flatten_element_type, PrimDef},
|
||||||
|
numpy::unpack_ndarray_var_tys,
|
||||||
|
},
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -65,10 +68,15 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
|
let elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty);
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let arg =
|
let arg = NDArrayValue::from_pointer_value(
|
||||||
NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None);
|
arg.into_pointer_value(),
|
||||||
|
ctx.get_llvm_type(generator, elem_ty),
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let ndims = arg.dim_sizes().size(ctx, generator);
|
let ndims = arg.dim_sizes().size(ctx, generator);
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
@ -143,13 +151,14 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -205,13 +214,14 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.int64,
|
ctx.primitives.int64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -283,13 +293,14 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint32,
|
ctx.primitives.uint32,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -350,13 +361,14 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.uint64,
|
ctx.primitives.uint64,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -416,13 +428,14 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -462,13 +475,14 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -502,13 +516,14 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -567,13 +582,14 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| {
|
|generator, ctx, val| {
|
||||||
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
let elem = call_bool(generator, ctx, (elem_ty, val))?;
|
||||||
|
|
||||||
@ -621,13 +637,14 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -671,13 +688,14 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|
NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -806,8 +824,8 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -906,9 +924,9 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_pointer_value(n, llvm_usize, None);
|
let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx
|
let n_sz_eqz = ctx
|
||||||
@ -926,7 +944,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_elem_ty, None)?;
|
||||||
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
|
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
@ -1068,8 +1086,8 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1114,6 +1132,7 @@ where
|
|||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||||
|
let llvm_arg_elem_ty = ctx.get_llvm_type(generator, arg_elem_ty);
|
||||||
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
|
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
@ -1121,7 +1140,7 @@ where
|
|||||||
ctx,
|
ctx,
|
||||||
ret_elem_ty,
|
ret_elem_ty,
|
||||||
None,
|
None,
|
||||||
NDArrayValue::from_pointer_value(x, llvm_usize, None),
|
NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, elem_val| {
|
|generator, ctx, elem_val| {
|
||||||
helper_call_numpy_unary_elementwise(
|
helper_call_numpy_unary_elementwise(
|
||||||
generator,
|
generator,
|
||||||
@ -1508,8 +1527,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1575,8 +1594,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1642,8 +1661,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1709,8 +1728,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1765,8 +1784,8 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1832,8 +1851,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1899,8 +1918,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
None,
|
None,
|
||||||
(x1, !is_ndarray1),
|
(x1_ty, x1, !is_ndarray1),
|
||||||
(x2, !is_ndarray2),
|
(x2_ty, x2, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
},
|
},
|
||||||
@ -1960,7 +1979,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
@ -2002,7 +2021,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
@ -2052,7 +2071,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
@ -2107,7 +2126,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
@ -2149,7 +2168,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
@ -2192,7 +2211,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
@ -2245,7 +2264,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
let n2_array = numpy::create_ndarray_const_shape(
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
generator,
|
generator,
|
||||||
@ -2340,7 +2359,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
@ -2383,7 +2402,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
};
|
};
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let dim0 = unsafe {
|
let dim0 = unsafe {
|
||||||
n1.dim_sizes()
|
n1.dim_sizes()
|
||||||
|
@ -1564,10 +1564,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
let left_val =
|
let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1);
|
||||||
NDArrayValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None);
|
let llvm_ndarray_dtype2 = ctx.get_llvm_type(generator, ndarray_dtype2);
|
||||||
let right_val =
|
|
||||||
NDArrayValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None);
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
|
left_val.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype1,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let right_val = NDArrayValue::from_pointer_value(
|
||||||
|
right_val.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype2,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let res = if op.base == Operator::MatMult {
|
let res = if op.base == Operator::MatMult {
|
||||||
// MatMult is the only binop which is not an elementwise op
|
// MatMult is the only binop which is not an elementwise op
|
||||||
@ -1591,8 +1602,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
BinopVariant::Normal => None,
|
BinopVariant::Normal => None,
|
||||||
BinopVariant::AugAssign => Some(left_val),
|
BinopVariant::AugAssign => Some(left_val),
|
||||||
},
|
},
|
||||||
(left_val.as_base_value().into(), false),
|
(ty1, left_val.as_base_value().into(), false),
|
||||||
(right_val.as_base_value().into(), false),
|
(ty2, right_val.as_base_value().into(), false),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
gen_binop_expr_with_values(
|
gen_binop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
@ -1616,8 +1627,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) =
|
let (ndarray_dtype, _) =
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
||||||
|
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
let ndarray_val = NDArrayValue::from_pointer_value(
|
let ndarray_val = NDArrayValue::from_pointer_value(
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
@ -1629,8 +1642,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
BinopVariant::Normal => None,
|
BinopVariant::Normal => None,
|
||||||
BinopVariant::AugAssign => Some(ndarray_val),
|
BinopVariant::AugAssign => Some(ndarray_val),
|
||||||
},
|
},
|
||||||
(left_val, !is_ndarray1),
|
(ty1, left_val, !is_ndarray1),
|
||||||
(right_val, !is_ndarray2),
|
(ty2, right_val, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
gen_binop_expr_with_values(
|
gen_binop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
@ -1810,8 +1823,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
|
|
||||||
let val = NDArrayValue::from_pointer_value(val.into_pointer_value(), llvm_usize, None);
|
let val = NDArrayValue::from_pointer_value(
|
||||||
|
val.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
// passing it to the elementwise codegen function
|
// passing it to the elementwise codegen function
|
||||||
@ -1902,15 +1921,21 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
let left_val =
|
let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1);
|
||||||
NDArrayValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None);
|
|
||||||
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
|
lhs.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype1,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
None,
|
None,
|
||||||
(left_val.as_base_value().into(), false),
|
(left_ty, left_val.as_base_value().into(), false),
|
||||||
(rhs, false),
|
(right_ty, rhs, false),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
@ -1941,8 +1966,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
None,
|
None,
|
||||||
(lhs, !is_ndarray1),
|
(left_ty, lhs, !is_ndarray1),
|
||||||
(rhs, !is_ndarray2),
|
(right_ty, rhs, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
@ -2771,8 +2796,12 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||||||
// elements over
|
// elements over
|
||||||
let subscripted_ndarray =
|
let subscripted_ndarray =
|
||||||
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
let ndarray =
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
NDArrayValue::from_pointer_value(subscripted_ndarray, llvm_usize, None);
|
subscripted_ndarray,
|
||||||
|
llvm_ndarray_data_t,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let num_dims = v.load_ndims(ctx);
|
let num_dims = v.load_ndims(ctx);
|
||||||
ndarray.store_ndims(
|
ndarray.store_ndims(
|
||||||
@ -3510,6 +3539,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||||
|
let llvm_ty = ctx.get_llvm_type(generator, *ty);
|
||||||
|
|
||||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||||
@ -3517,7 +3547,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let v = NDArrayValue::from_pointer_value(v, usize, None);
|
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None);
|
||||||
|
|
||||||
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ use super::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::{arraylike_flatten_element_type, PrimDef},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
@ -42,6 +42,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
@ -54,7 +55,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
|
|
||||||
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_usize, None))
|
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a dynamic shape.
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
@ -473,8 +474,8 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
res: NDArrayValue<'ctx>,
|
res: NDArrayValue<'ctx>,
|
||||||
lhs: (BasicValueEnum<'ctx>, bool),
|
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||||
rhs: (BasicValueEnum<'ctx>, bool),
|
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||||
value_fn: ValueFn,
|
value_fn: ValueFn,
|
||||||
) -> Result<NDArrayValue<'ctx>, String>
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
@ -487,8 +488,8 @@ where
|
|||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (lhs_val, lhs_scalar) = lhs;
|
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
||||||
let (rhs_val, rhs_scalar) = rhs;
|
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!(lhs_scalar && rhs_scalar),
|
!(lhs_scalar && rhs_scalar),
|
||||||
@ -499,14 +500,26 @@ where
|
|||||||
|
|
||||||
// Assert that all ndarray operands are broadcastable to the target size
|
// Assert that all ndarray operands are broadcastable to the target size
|
||||||
if !lhs_scalar {
|
if !lhs_scalar {
|
||||||
let lhs_val =
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
||||||
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
||||||
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
|
lhs_val.into_pointer_value(),
|
||||||
|
llvm_lhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !rhs_scalar {
|
if !rhs_scalar {
|
||||||
let rhs_val =
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
||||||
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
||||||
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
|
rhs_val.into_pointer_value(),
|
||||||
|
llvm_rhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -514,8 +527,14 @@ where
|
|||||||
let lhs_elem = if lhs_scalar {
|
let lhs_elem = if lhs_scalar {
|
||||||
lhs_val
|
lhs_val
|
||||||
} else {
|
} else {
|
||||||
let lhs =
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
||||||
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
||||||
|
let lhs = NDArrayValue::from_pointer_value(
|
||||||
|
lhs_val.into_pointer_value(),
|
||||||
|
llvm_lhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
||||||
|
|
||||||
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
|
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
|
||||||
@ -524,8 +543,14 @@ where
|
|||||||
let rhs_elem = if rhs_scalar {
|
let rhs_elem = if rhs_scalar {
|
||||||
rhs_val
|
rhs_val
|
||||||
} else {
|
} else {
|
||||||
let rhs =
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
||||||
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
||||||
|
let rhs = NDArrayValue::from_pointer_value(
|
||||||
|
rhs_val.into_pointer_value(),
|
||||||
|
llvm_rhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
||||||
|
|
||||||
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
|
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
|
||||||
@ -671,7 +696,7 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
value: BasicValueEnum<'ctx>,
|
(ty, value): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
@ -679,7 +704,9 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(v)
|
BasicValueEnum::PointerValue(v)
|
||||||
if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
|
if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
|
||||||
{
|
{
|
||||||
NDArrayValue::from_pointer_value(v, llvm_usize, None).load_ndims(ctx)
|
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
||||||
|
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
||||||
@ -694,7 +721,6 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
|
||||||
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||||
src_lst: ListValue<'ctx>,
|
src_lst: ListValue<'ctx>,
|
||||||
dim: u64,
|
dim: u64,
|
||||||
@ -727,6 +753,20 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|_, _| Ok(llvm_usize.const_int(1, false)),
|
|_, _| Ok(llvm_usize.const_int(1, false)),
|
||||||
|generator, ctx, _, i| {
|
|generator, ctx, _, i| {
|
||||||
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
|
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
|
||||||
|
let offset = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(
|
||||||
|
offset,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
dst_arr.get_type().element_type().size_of().unwrap(),
|
||||||
|
offset.get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let dst_ptr =
|
let dst_ptr =
|
||||||
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
|
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
|
||||||
@ -741,7 +781,6 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_from_ndlist_impl(
|
ndarray_from_ndlist_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
|
||||||
(dst_arr, dst_ptr),
|
(dst_arr, dst_ptr),
|
||||||
nested_lst_elem,
|
nested_lst_elem,
|
||||||
dim + 1,
|
dim + 1,
|
||||||
@ -760,7 +799,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
let lst_len = src_lst.load_size(ctx, None);
|
let lst_len = src_lst.load_size(ctx, None);
|
||||||
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
||||||
let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap();
|
let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
let cpy_len = ctx
|
let cpy_len = ctx
|
||||||
@ -816,7 +855,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
||||||
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
||||||
let object = NDArrayValue::from_pointer_value(object, llvm_usize, None);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
|
||||||
|
|
||||||
let ndarray = gen_if_else_expr_callback(
|
let ndarray = gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
@ -878,7 +918,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_sliced_copyto_impl(
|
ndarray_sliced_copyto_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
|
||||||
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||||
(object, object.data().base_ptr(ctx, generator)),
|
(object, object.data().base_ptr(ctx, generator)),
|
||||||
0,
|
0,
|
||||||
@ -892,6 +931,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
return Ok(NDArrayValue::from_pointer_value(
|
return Ok(NDArrayValue::from_pointer_value(
|
||||||
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
||||||
|
llvm_elem_ty,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
));
|
));
|
||||||
@ -1026,7 +1066,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_from_ndlist_impl(
|
ndarray_from_ndlist_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
|
||||||
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||||
object,
|
object,
|
||||||
0,
|
0,
|
||||||
@ -1099,7 +1138,6 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
|
||||||
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||||
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||||
dim: u64,
|
dim: u64,
|
||||||
@ -1108,10 +1146,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let llvm_i1 = ctx.ctx.bool_type();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
|
||||||
|
|
||||||
|
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
||||||
|
|
||||||
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
||||||
if slices.is_empty() {
|
if slices.is_empty() {
|
||||||
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
|
||||||
|
|
||||||
let stride = call_ndarray_calc_size(
|
let stride = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -1162,9 +1202,29 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|generator, ctx, _, src_i| {
|
|generator, ctx, _, src_i| {
|
||||||
// Calculate the offset of the active slice
|
// Calculate the offset of the active slice
|
||||||
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
|
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
|
||||||
|
let src_data_offset = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(
|
||||||
|
src_data_offset,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_cast(sizeof_elem, src_data_offset.get_type(), "")
|
||||||
|
.unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
let dst_i =
|
let dst_i =
|
||||||
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
|
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
|
||||||
|
let dst_data_offset = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(
|
||||||
|
dst_data_offset,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_cast(sizeof_elem, dst_data_offset.get_type(), "")
|
||||||
|
.unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let (src_ptr, dst_ptr) = unsafe {
|
let (src_ptr, dst_ptr) = unsafe {
|
||||||
(
|
(
|
||||||
@ -1176,7 +1236,6 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_sliced_copyto_impl(
|
ndarray_sliced_copyto_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
|
||||||
(dst_arr, dst_ptr),
|
(dst_arr, dst_ptr),
|
||||||
(src_arr, src_ptr),
|
(src_arr, src_ptr),
|
||||||
dim + 1,
|
dim + 1,
|
||||||
@ -1293,7 +1352,6 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_sliced_copyto_impl(
|
ndarray_sliced_copyto_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
|
||||||
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||||
(this, this.data().base_ptr(ctx, generator)),
|
(this, this.data().base_ptr(ctx, generator)),
|
||||||
0,
|
0,
|
||||||
@ -1376,8 +1434,8 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
|
|||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
res: Option<NDArrayValue<'ctx>>,
|
||||||
lhs: (BasicValueEnum<'ctx>, bool),
|
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||||
rhs: (BasicValueEnum<'ctx>, bool),
|
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||||
value_fn: ValueFn,
|
value_fn: ValueFn,
|
||||||
) -> Result<NDArrayValue<'ctx>, String>
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
@ -1390,8 +1448,8 @@ where
|
|||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (lhs_val, lhs_scalar) = lhs;
|
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
||||||
let (rhs_val, rhs_scalar) = rhs;
|
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!(lhs_scalar && rhs_scalar),
|
!(lhs_scalar && rhs_scalar),
|
||||||
@ -1402,10 +1460,22 @@ where
|
|||||||
|
|
||||||
let ndarray = res.unwrap_or_else(|| {
|
let ndarray = res.unwrap_or_else(|| {
|
||||||
if lhs_scalar && rhs_scalar {
|
if lhs_scalar && rhs_scalar {
|
||||||
let lhs_val =
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
||||||
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
||||||
let rhs_val =
|
let lhs_val = NDArrayValue::from_pointer_value(
|
||||||
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
|
lhs_val.into_pointer_value(),
|
||||||
|
llvm_lhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
||||||
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
||||||
|
let rhs_val = NDArrayValue::from_pointer_value(
|
||||||
|
rhs_val.into_pointer_value(),
|
||||||
|
llvm_rhs_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
||||||
|
|
||||||
@ -1421,8 +1491,14 @@ where
|
|||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
} else {
|
} else {
|
||||||
|
let dtype = arraylike_flatten_element_type(
|
||||||
|
&mut ctx.unifier,
|
||||||
|
if lhs_scalar { rhs_ty } else { lhs_ty },
|
||||||
|
);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
||||||
let ndarray = NDArrayValue::from_pointer_value(
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
||||||
|
llvm_elem_ty,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
@ -1981,11 +2057,18 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||||||
let this_arg =
|
let this_arg =
|
||||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||||
|
|
||||||
|
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
|
||||||
|
|
||||||
ndarray_copy_impl(
|
ndarray_copy_impl(
|
||||||
generator,
|
generator,
|
||||||
context,
|
context,
|
||||||
this_elem_ty,
|
this_elem_ty,
|
||||||
NDArrayValue::from_pointer_value(this_arg.into_pointer_value(), llvm_usize, None),
|
NDArrayValue::from_pointer_value(
|
||||||
|
this_arg.into_pointer_value(),
|
||||||
|
llvm_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
.map(NDArrayValue::into)
|
.map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
@ -2004,6 +2087,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
|
let this_elem_ty = arraylike_flatten_element_type(&mut context.unifier, this_ty);
|
||||||
let this_arg = obj
|
let this_arg = obj
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -2014,10 +2098,12 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
let value_ty = fun.0.args[0].ty;
|
let value_ty = fun.0.args[0].ty;
|
||||||
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
||||||
|
|
||||||
|
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
|
||||||
|
|
||||||
ndarray_fill_flattened(
|
ndarray_fill_flattened(
|
||||||
generator,
|
generator,
|
||||||
context,
|
context,
|
||||||
NDArrayValue::from_pointer_value(this_arg, llvm_usize, None),
|
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|
||||||
|generator, ctx, _| {
|
|generator, ctx, _| {
|
||||||
let value = if value_arg.is_pointer_value() {
|
let value = if value_arg.is_pointer_value() {
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
@ -2058,7 +2144,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
// Dimensions are reversed in the transposed array
|
||||||
@ -2177,7 +2264,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
@ -2454,14 +2542,19 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "ndarray_dot";
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
let (x1_ty, x1) = x1;
|
let (x1_ty, x1) = x1;
|
||||||
let (_, x2) = x2;
|
let (x2_ty, x2) = x2;
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
match (x1, x2) {
|
match (x1, x2) {
|
||||||
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
|
let n1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
|
||||||
let n2 = NDArrayValue::from_pointer_value(n2, llvm_usize, None);
|
let n2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty);
|
||||||
|
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
||||||
|
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
||||||
|
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
@ -2501,7 +2594,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.as_basic_value_enum(),
|
.as_basic_value_enum(),
|
||||||
_ => codegen_unreachable!(ctx),
|
_ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
|
||||||
};
|
};
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
let acc_val = match acc_val {
|
let acc_val = match acc_val {
|
||||||
@ -2515,7 +2608,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.build_float_add(e1, product.into_float_value(), "")
|
.build_float_add(e1, product.into_float_value(), "")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.as_basic_value_enum(),
|
.as_basic_value_enum(),
|
||||||
_ => codegen_unreachable!(ctx),
|
_ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
|
||||||
};
|
};
|
||||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
|
||||||
|
@ -67,21 +67,29 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
||||||
let Ok(_) = PointerType::try_from(ndarray_data_ty) else {
|
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
||||||
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
||||||
};
|
};
|
||||||
|
let ndarray_data = ndarray_pdata.get_element_type();
|
||||||
|
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if ndarray_data.get_bit_width() != 8 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||||
|
ndarray_data.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn llvm_type(
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
ctx: &'ctx Context,
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
) -> PointerType<'ctx> {
|
|
||||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
|
|
||||||
//
|
//
|
||||||
// * num_dims: Number of dimensions in the array
|
// * num_dims: Number of dimensions in the array
|
||||||
// * dims: Pointer to an array containing the size of each dimension
|
// * dims: Pointer to an array containing the size of each dimension
|
||||||
@ -89,13 +97,13 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
let field_tys = [
|
let field_tys = [
|
||||||
llvm_usize.into(),
|
llvm_usize.into(),
|
||||||
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
||||||
dtype.ptr_type(AddressSpace::default()).into(),
|
ctx.i8_type().ptr_type(AddressSpace::default()).into(),
|
||||||
];
|
];
|
||||||
|
|
||||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an instance of [`ListType`].
|
/// Creates an instance of [`NDArrayType`].
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new<G: CodeGenerator + ?Sized>(
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
generator: &G,
|
generator: &G,
|
||||||
@ -103,24 +111,21 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let llvm_usize = generator.get_size_type(ctx);
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
let llvm_ndarray = Self::llvm_type(ctx, dtype, llvm_usize);
|
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||||
|
|
||||||
NDArrayType::from_type(llvm_ndarray, llvm_usize)
|
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [`NDArrayType`] from a [`PointerType`].
|
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
pub fn from_type(
|
||||||
|
ptr_ty: PointerType<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Self {
|
||||||
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
|
|
||||||
NDArrayType {
|
NDArrayType { ty: ptr_ty, dtype, llvm_usize }
|
||||||
ty: ptr_ty,
|
|
||||||
dtype: ptr_ty
|
|
||||||
.get_element_type()
|
|
||||||
.try_into()
|
|
||||||
.expect("Expected BasicTypeEnum for dtype of NDArray"),
|
|
||||||
llvm_usize,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the type of the `size` field of this `ndarray` type.
|
/// Returns the type of the `size` field of this `ndarray` type.
|
||||||
@ -207,7 +212,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
|||||||
) -> Self::Value {
|
) -> Self::Value {
|
||||||
debug_assert_eq!(value.get_type(), self.as_base_type());
|
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||||
|
|
||||||
NDArrayValue::from_pointer_value(value, self.llvm_usize, name)
|
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_base_type(&self) -> Self::Base {
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{AnyTypeEnum, BasicTypeEnum, IntType},
|
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
@ -20,6 +20,7 @@ use crate::codegen::{
|
|||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct NDArrayValue<'ctx> {
|
pub struct NDArrayValue<'ctx> {
|
||||||
value: PointerValue<'ctx>,
|
value: PointerValue<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
}
|
}
|
||||||
@ -38,12 +39,13 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_pointer_value(
|
pub fn from_pointer_value(
|
||||||
ptr: PointerValue<'ctx>,
|
ptr: PointerValue<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
NDArrayValue { value: ptr, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
@ -138,6 +140,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
|
|
||||||
/// Stores the array of data elements `data` into this instance.
|
/// Stores the array of data elements `data` into this instance.
|
||||||
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||||
|
.unwrap();
|
||||||
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,7 +155,15 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
elem_ty: BasicTypeEnum<'ctx>,
|
elem_ty: BasicTypeEnum<'ctx>,
|
||||||
size: IntValue<'ctx>,
|
size: IntValue<'ctx>,
|
||||||
) {
|
) {
|
||||||
self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap());
|
let itemsize =
|
||||||
|
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
|
||||||
|
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
||||||
|
|
||||||
|
// TODO: What about alignment?
|
||||||
|
self.store_data(
|
||||||
|
ctx,
|
||||||
|
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||||
@ -164,7 +178,7 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
|||||||
type Type = NDArrayType<'ctx>;
|
type Type = NDArrayType<'ctx>;
|
||||||
|
|
||||||
fn get_type(&self) -> Self::Type {
|
fn get_type(&self) -> Self::Type {
|
||||||
NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize)
|
NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_base_value(&self) -> Self::Base {
|
fn as_base_value(&self) -> Self::Base {
|
||||||
@ -282,10 +296,10 @@ pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
|||||||
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
fn element_type<G: CodeGenerator + ?Sized>(
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
_: &G,
|
||||||
) -> AnyTypeEnum<'ctx> {
|
) -> AnyTypeEnum<'ctx> {
|
||||||
self.0.data().base_ptr(ctx, generator).get_type().get_element_type()
|
self.0.dtype.as_any_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
@ -318,15 +332,37 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
idx: &IntValue<'ctx>,
|
idx: &IntValue<'ctx>,
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
unsafe {
|
let sizeof_elem = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
self.element_type(ctx, generator).size_of().unwrap(),
|
||||||
|
idx.get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
|
||||||
|
let ptr = unsafe {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_in_bounds_gep(
|
.build_in_bounds_gep(
|
||||||
self.base_ptr(ctx, generator),
|
self.base_ptr(ctx, generator),
|
||||||
&[*idx],
|
&[idx],
|
||||||
name.unwrap_or_default(),
|
name.unwrap_or_default(),
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
};
|
||||||
|
|
||||||
|
// Current implementation is transparent - The returned pointer type is
|
||||||
|
// already cast into the expected type, allowing for immediately
|
||||||
|
// load/store.
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
@ -347,7 +383,20 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
|
||||||
|
|
||||||
|
// Current implementation is transparent - The returned pointer type is
|
||||||
|
// already cast into the expected type, allowing for immediately
|
||||||
|
// load/store.
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,8 +430,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||||||
);
|
);
|
||||||
|
|
||||||
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
||||||
|
let sizeof_elem = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
self.element_type(ctx, generator).size_of().unwrap(),
|
||||||
|
index.get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
|
||||||
|
|
||||||
unsafe {
|
let ptr = unsafe {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_in_bounds_gep(
|
.build_in_bounds_gep(
|
||||||
self.base_ptr(ctx, generator),
|
self.base_ptr(ctx, generator),
|
||||||
@ -390,7 +448,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||||||
name.unwrap_or_default(),
|
name.unwrap_or_default(),
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
};
|
||||||
|
// TODO: Current implementation is transparent
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
@ -455,7 +523,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) }
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) };
|
||||||
|
// TODO: Current implementation is transparent
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,6 +144,7 @@ def test_ndarray_array():
|
|||||||
|
|
||||||
# Copy
|
# Copy
|
||||||
n2_cpy: ndarray[float, 2] = np_array(n2, copy=False)
|
n2_cpy: ndarray[float, 2] = np_array(n2, copy=False)
|
||||||
|
output_ndarray_float_2(n2_cpy)
|
||||||
n2_cpy.fill(0.0)
|
n2_cpy.fill(0.0)
|
||||||
output_ndarray_float_2(n2_cpy)
|
output_ndarray_float_2(n2_cpy)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user