forked from M-Labs/nac3
[core] codegen/ndarray: Make ndims non-optional
Now that everything is ported to use strided impl, dynamic-ndim ndarray instances do not exist anymore.
This commit is contained in:
parent
3ac1083734
commit
12fddc3533
@ -464,7 +464,7 @@ fn format_rpc_arg<'ctx>(
|
|||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
let dtype = ctx.get_llvm_type(generator, elem_ty);
|
let dtype = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims))
|
let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims)
|
||||||
.map_value(arg.into_pointer_value(), None);
|
.map_value(arg.into_pointer_value(), None);
|
||||||
|
|
||||||
let ndims = llvm_usize.const_int(ndims, false);
|
let ndims = llvm_usize.const_int(ndims, false);
|
||||||
@ -597,7 +597,7 @@ fn format_rpc_ret<'ctx>(
|
|||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||||
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims))
|
let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
|
|
||||||
// NOTE: Current content of `ndarray`:
|
// NOTE: Current content of `ndarray`:
|
||||||
|
@ -1107,7 +1107,7 @@ impl InnerResolver {
|
|||||||
self.global_value_ids.write().insert(id, obj.into());
|
self.global_value_ids.write().insert(id, obj.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndims = llvm_ndarray.ndims().unwrap();
|
let ndims = llvm_ndarray.ndims();
|
||||||
|
|
||||||
// Obtain the shape of the ndarray
|
// Obtain the shape of the ndarray
|
||||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||||
|
@ -1652,7 +1652,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
unsafe { out.create_data(generator, ctx) };
|
unsafe { out.create_data(generator, ctx) };
|
||||||
@ -1694,7 +1694,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
};
|
};
|
||||||
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
|
||||||
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
||||||
unsafe { q.create_data(generator, ctx) };
|
unsafe { q.create_data(generator, ctx) };
|
||||||
|
|
||||||
@ -1746,8 +1746,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
};
|
};
|
||||||
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1));
|
let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1);
|
||||||
let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
|
||||||
|
|
||||||
let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
|
let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
|
||||||
unsafe { u.create_data(generator, ctx) };
|
unsafe { u.create_data(generator, ctx) };
|
||||||
@ -1796,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
unsafe { out.create_data(generator, ctx) };
|
unsafe { out.create_data(generator, ctx) };
|
||||||
@ -1838,7 +1838,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
};
|
};
|
||||||
|
|
||||||
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
|
||||||
.construct_dyn_shape(generator, ctx, &[d0, d1], None);
|
.construct_dyn_shape(generator, ctx, &[d0, d1], None);
|
||||||
unsafe { out.create_data(generator, ctx) };
|
unsafe { out.create_data(generator, ctx) };
|
||||||
|
|
||||||
@ -1880,7 +1880,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
};
|
};
|
||||||
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
|
||||||
|
|
||||||
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
||||||
unsafe { l.create_data(generator, ctx) };
|
unsafe { l.create_data(generator, ctx) };
|
||||||
@ -1924,7 +1924,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, ndims, llvm_usize, None);
|
||||||
|
|
||||||
if !x1.get_type().element_type().is_float_type() {
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
@ -1940,7 +1940,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.construct_unsized(generator, ctx, &x2, None); // x2.shape == []
|
.construct_unsized(generator, ctx, &x2, None); // x2.shape == []
|
||||||
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
|
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
|
||||||
|
|
||||||
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
unsafe { out.create_data(generator, ctx) };
|
unsafe { out.create_data(generator, ctx) };
|
||||||
@ -1979,7 +1979,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
|
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
|
||||||
let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1))
|
let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1)
|
||||||
.construct_const_shape(generator, ctx, &[1], None);
|
.construct_const_shape(generator, ctx, &[1], None);
|
||||||
unsafe { det.create_data(generator, ctx) };
|
unsafe { det.create_data(generator, ctx) };
|
||||||
|
|
||||||
@ -2008,13 +2008,13 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
|
|
||||||
let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None);
|
let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None);
|
||||||
assert_eq!(x1.get_type().ndims(), Some(2));
|
assert_eq!(x1.get_type().ndims(), 2);
|
||||||
|
|
||||||
if !x1.get_type().element_type().is_float_type() {
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
|
||||||
|
|
||||||
let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
|
let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
|
||||||
t.copy_shape_from_ndarray(generator, ctx, x1);
|
t.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
@ -2053,13 +2053,13 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
|
|
||||||
let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None);
|
let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None);
|
||||||
assert_eq!(x1.get_type().ndims(), Some(2));
|
assert_eq!(x1.get_type().ndims(), 2);
|
||||||
|
|
||||||
if !x1.get_type().element_type().is_float_type() {
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
|
||||||
|
|
||||||
let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
|
let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
|
||||||
h.copy_shape_from_ndarray(generator, ctx, x1);
|
h.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
@ -520,7 +520,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||||
);
|
);
|
||||||
|
|
||||||
NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into()
|
NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => unreachable!(
|
_ => unreachable!(
|
||||||
|
@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>(
|
|||||||
|
|
||||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
|
||||||
.construct_numpy_empty(generator, context, &shape, None);
|
.construct_numpy_empty(generator, context, &shape, None);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>(
|
|||||||
|
|
||||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
|
||||||
.construct_numpy_zeros(generator, context, dtype, &shape, None);
|
.construct_numpy_zeros(generator, context, dtype, &shape, None);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>(
|
|||||||
|
|
||||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
|
||||||
.construct_numpy_ones(generator, context, dtype, &shape, None);
|
.construct_numpy_ones(generator, context, dtype, &shape, None);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
@ -120,8 +120,13 @@ pub fn gen_ndarray_full<'ctx>(
|
|||||||
|
|
||||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full(
|
||||||
.construct_numpy_full(generator, context, &shape, fill_value_arg, None);
|
generator,
|
||||||
|
context,
|
||||||
|
&shape,
|
||||||
|
fill_value_arg,
|
||||||
|
None,
|
||||||
|
);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>(
|
|||||||
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
|
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2)
|
||||||
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
|
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
@ -246,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>(
|
|||||||
.builder
|
.builder
|
||||||
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
|
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2)
|
||||||
.construct_numpy_identity(generator, context, dtype, n, None);
|
.construct_numpy_identity(generator, context, dtype, n, None);
|
||||||
Ok(ndarray.as_base_value())
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
@ -315,8 +320,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
||||||
|
|
||||||
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
|
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
|
||||||
assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1));
|
assert_eq!(a.get_type().ndims(), 1);
|
||||||
assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1));
|
assert_eq!(b.get_type().ndims(), 1);
|
||||||
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
|
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
|
||||||
|
|
||||||
// Check shapes.
|
// Check shapes.
|
||||||
|
@ -447,10 +447,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
|||||||
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
|
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
|
||||||
.to_ndarray(generator, ctx);
|
.to_ndarray(generator, ctx);
|
||||||
|
|
||||||
let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()]
|
let broadcast_ndims =
|
||||||
.iter()
|
[target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap();
|
||||||
.filter_map(|ndims| *ndims)
|
|
||||||
.max();
|
|
||||||
let broadcast_result = NDArrayType::new(
|
let broadcast_result = NDArrayType::new(
|
||||||
generator,
|
generator,
|
||||||
ctx.ctx,
|
ctx.ctx,
|
||||||
|
@ -464,6 +464,6 @@ fn test_classes_ndarray_type_new() {
|
|||||||
let llvm_i32 = ctx.i32_type();
|
let llvm_i32 = ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(&ctx);
|
let llvm_usize = generator.get_size_type(&ctx);
|
||||||
|
|
||||||
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None);
|
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2);
|
||||||
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty);
|
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty);
|
||||||
assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int));
|
assert!(self.ndims >= ndims_int);
|
||||||
assert_eq!(dtype, self.dtype);
|
assert_eq!(dtype, self.dtype);
|
||||||
|
|
||||||
let list_value = list.as_i8_list(generator, ctx);
|
let list_value = list.as_i8_list(generator, ctx);
|
||||||
@ -61,7 +61,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
generator, ctx, list_value, ndims, &shape,
|
generator, ctx, list_value, ndims, &shape,
|
||||||
);
|
);
|
||||||
|
|
||||||
let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int))
|
let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int)
|
||||||
.construct_uninitialized(generator, ctx, name);
|
.construct_uninitialized(generator, ctx, name);
|
||||||
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||||
unsafe { ndarray.create_data(generator, ctx) };
|
unsafe { ndarray.create_data(generator, ctx) };
|
||||||
@ -93,12 +93,12 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
if ndims == 1 {
|
if ndims == 1 {
|
||||||
// `list` is not nested
|
// `list` is not nested
|
||||||
assert_eq!(ndims, 1);
|
assert_eq!(ndims, 1);
|
||||||
assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims));
|
assert!(self.ndims >= ndims);
|
||||||
assert_eq!(dtype, self.dtype);
|
assert_eq!(dtype, self.dtype);
|
||||||
|
|
||||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1))
|
let ndarray = Self::new(generator, ctx.ctx, dtype, 1)
|
||||||
.construct_uninitialized(generator, ctx, name);
|
.construct_uninitialized(generator, ctx, name);
|
||||||
|
|
||||||
// Set data
|
// Set data
|
||||||
@ -170,7 +170,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None)
|
NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation of `np_array(<ndarray>, copy=copy)`.
|
/// Implementation of `np_array(<ndarray>, copy=copy)`.
|
||||||
@ -183,9 +183,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
assert_eq!(ndarray.get_type().dtype, self.dtype);
|
assert_eq!(ndarray.get_type().dtype, self.dtype);
|
||||||
assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self
|
assert!(self.ndims >= ndarray.get_type().ndims);
|
||||||
.ndims
|
|
||||||
.is_none_or(|self_ndims| self_ndims >= ndarray_ndims)));
|
|
||||||
assert_eq!(copy.get_type(), ctx.ctx.bool_type());
|
assert_eq!(copy.get_type(), ctx.ctx.bool_type());
|
||||||
|
|
||||||
let ndarray_val = gen_if_else_expr_callback(
|
let ndarray_val = gen_if_else_expr_callback(
|
||||||
|
@ -47,7 +47,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
NDArrayOut::NewNDArray { dtype } => {
|
NDArrayOut::NewNDArray { dtype } => {
|
||||||
// Create a new ndarray based on the broadcast shape.
|
// Create a new ndarray based on the broadcast shape.
|
||||||
let result_ndarray =
|
let result_ndarray =
|
||||||
NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims))
|
NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
result_ndarray.copy_shape_from_array(
|
result_ndarray.copy_shape_from_array(
|
||||||
generator,
|
generator,
|
||||||
|
@ -38,7 +38,7 @@ mod nditer;
|
|||||||
pub struct NDArrayType<'ctx> {
|
pub struct NDArrayType<'ctx> {
|
||||||
ty: PointerType<'ctx>,
|
ty: PointerType<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
ndims: u64,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
generator: &G,
|
generator: &G,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
ndims: u64,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let llvm_usize = generator.get_size_type(ctx);
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||||
@ -132,7 +132,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(!inputs.is_empty());
|
assert!(!inputs.is_empty());
|
||||||
|
|
||||||
Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max())
|
Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
||||||
@ -145,7 +145,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
let llvm_usize = generator.get_size_type(ctx);
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||||
|
|
||||||
NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize }
|
NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
||||||
@ -164,7 +164,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
NDArrayType {
|
NDArrayType {
|
||||||
ty: Self::llvm_type(ctx.ctx, llvm_usize),
|
ty: Self::llvm_type(ctx.ctx, llvm_usize),
|
||||||
dtype: llvm_dtype,
|
dtype: llvm_dtype,
|
||||||
ndims: Some(ndims),
|
ndims,
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,7 +174,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
pub fn from_type(
|
pub fn from_type(
|
||||||
ptr_ty: PointerType<'ctx>,
|
ptr_ty: PointerType<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
ndims: u64,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
@ -196,7 +196,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
|
|
||||||
/// Returns the number of dimensions of this `ndarray` type.
|
/// Returns the number of dimensions of this `ndarray` type.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn ndims(&self) -> Option<u64> {
|
pub fn ndims(&self) -> u64 {
|
||||||
self.ndims
|
self.ndims
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -286,35 +286,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
let ndims = self.llvm_usize.const_int(self.ndims, false);
|
||||||
|
|
||||||
let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
self.construct_impl(generator, ctx, ndims, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`.
|
|
||||||
///
|
|
||||||
/// `shape` and `strides` will be automatically allocated onto the stack.
|
|
||||||
///
|
|
||||||
/// The returned ndarray's content will be:
|
|
||||||
/// - `data`: uninitialized.
|
|
||||||
/// - `itemsize`: set to the size of `dtype`.
|
|
||||||
/// - `ndims`: set to the value of `ndims`.
|
|
||||||
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
#[deprecated = "Prefer construct_uninitialized or construct_*_shape."]
|
|
||||||
#[must_use]
|
|
||||||
pub fn construct_dyn_ndims<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
name: Option<&'ctx str>,
|
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
|
||||||
assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)");
|
|
||||||
|
|
||||||
self.construct_impl(generator, ctx, ndims, name)
|
self.construct_impl(generator, ctx, ndims, name)
|
||||||
}
|
}
|
||||||
@ -330,9 +302,9 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
shape: &[u64],
|
shape: &[u64],
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims));
|
assert_eq!(shape.len() as u64, self.ndims);
|
||||||
|
|
||||||
let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64))
|
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
|
||||||
.construct_uninitialized(generator, ctx, name);
|
.construct_uninitialized(generator, ctx, name);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
@ -365,9 +337,9 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
shape: &[IntValue<'ctx>],
|
shape: &[IntValue<'ctx>],
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims));
|
assert_eq!(shape.len() as u64, self.ndims);
|
||||||
|
|
||||||
let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64))
|
let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
|
||||||
.construct_uninitialized(generator, ctx, name);
|
.construct_uninitialized(generator, ctx, name);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
@ -407,7 +379,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
let value = value.as_basic_value_enum();
|
let value = value.as_basic_value_enum();
|
||||||
|
|
||||||
assert_eq!(value.get_type(), self.dtype);
|
assert_eq!(value.get_type(), self.dtype);
|
||||||
assert!(self.ndims.is_none_or(|ndims| ndims == 0));
|
assert_eq!(self.ndims, 0);
|
||||||
|
|
||||||
// We have to put the value on the stack to get a data pointer.
|
// We have to put the value on the stack to get a data pointer.
|
||||||
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap();
|
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap();
|
||||||
|
@ -163,13 +163,8 @@ impl<'ctx> NDIterType<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
assert!(
|
|
||||||
ndarray.get_type().ndims().is_some(),
|
|
||||||
"NDIter requires ndims of NDArray to be known."
|
|
||||||
);
|
|
||||||
|
|
||||||
let nditer = self.raw_alloca_var(generator, ctx, None);
|
let nditer = self.raw_alloca_var(generator, ctx, None);
|
||||||
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false);
|
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false);
|
||||||
|
|
||||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||||
let indices =
|
let indices =
|
||||||
|
@ -101,11 +101,10 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
target_ndims: u64,
|
target_ndims: u64,
|
||||||
target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims));
|
assert!(self.ndims <= target_ndims);
|
||||||
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
|
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
|
||||||
|
|
||||||
let broadcast_ndarray =
|
let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims)
|
||||||
NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims))
|
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
broadcast_ndarray.copy_shape_from_array(
|
broadcast_ndarray.copy_shape_from_array(
|
||||||
generator,
|
generator,
|
||||||
@ -199,14 +198,13 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
ndarrays: &[NDArrayValue<'ctx>],
|
ndarrays: &[NDArrayValue<'ctx>],
|
||||||
) -> BroadcastAllResult<'ctx, G> {
|
) -> BroadcastAllResult<'ctx, G> {
|
||||||
assert!(!ndarrays.is_empty());
|
assert!(!ndarrays.is_empty());
|
||||||
assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some()));
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
// Infer the broadcast output ndims.
|
// Infer the broadcast output ndims.
|
||||||
let broadcast_ndims_int =
|
let broadcast_ndims_int =
|
||||||
ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap();
|
ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap();
|
||||||
assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int));
|
assert!(self.ndims() >= broadcast_ndims_int);
|
||||||
|
|
||||||
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false);
|
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false);
|
||||||
let broadcast_shape = ArraySliceValue::from_ptr_val(
|
let broadcast_shape = ArraySliceValue::from_ptr_val(
|
||||||
@ -223,10 +221,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
let shape_entries = ndarrays
|
let shape_entries = ndarrays
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ndarray| {
|
.map(|ndarray| {
|
||||||
(
|
(ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims())
|
||||||
ndarray.shape().as_slice_value(ctx, generator),
|
|
||||||
ndarray.get_type().ndims().unwrap(),
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape);
|
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape);
|
||||||
|
@ -121,9 +121,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
.alloca_var(generator, ctx, self.name);
|
.alloca_var(generator, ctx, self.name);
|
||||||
|
|
||||||
// Set ndims and shape.
|
// Set ndims and shape.
|
||||||
let ndims = self
|
let ndims = self.llvm_usize.const_int(self.ndims, false);
|
||||||
.ndims
|
|
||||||
.map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false));
|
|
||||||
result.store_ndims(ctx, ndims);
|
result.store_ndims(ctx, ndims);
|
||||||
|
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
@ -180,7 +178,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
||||||
|
|
||||||
// Allocate the resulting ndarray.
|
// Allocate the resulting ndarray.
|
||||||
let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims))
|
let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims)
|
||||||
.construct_uninitialized(generator, ctx, carray.name);
|
.construct_uninitialized(generator, ctx, carray.name);
|
||||||
|
|
||||||
// Copy shape and update strides
|
// Copy shape and update strides
|
||||||
|
@ -98,8 +98,8 @@ impl<'ctx> From<NDIndexValue<'ctx>> for PointerValue<'ctx> {
|
|||||||
impl<'ctx> NDArrayValue<'ctx> {
|
impl<'ctx> NDArrayValue<'ctx> {
|
||||||
/// Get the expected `ndims` after indexing with `indices`.
|
/// Get the expected `ndims` after indexing with `indices`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option<u64> {
|
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 {
|
||||||
let mut ndims = self.ndims?;
|
let mut ndims = self.ndims;
|
||||||
|
|
||||||
for index in indices {
|
for index in indices {
|
||||||
match index {
|
match index {
|
||||||
@ -113,7 +113,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(ndims)
|
ndims
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
|
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
|
||||||
@ -127,8 +127,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
indices: &[RustNDIndex<'ctx>],
|
indices: &[RustNDIndex<'ctx>],
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
|
|
||||||
|
|
||||||
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
|
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
|
||||||
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims)
|
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
|
@ -29,16 +29,8 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
|
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
|
||||||
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
|
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
|
||||||
) -> NDArrayValue<'ctx> {
|
) -> NDArrayValue<'ctx> {
|
||||||
assert!(
|
assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims);
|
||||||
in_a.ndims.is_some_and(|ndims| ndims >= 2),
|
assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims);
|
||||||
"in_a (which is {:?}) must be compile-time known and >= 2",
|
|
||||||
in_a.ndims
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
in_b.ndims.is_some_and(|ndims| ndims >= 2),
|
|
||||||
"in_b (which is {:?}) must be compile-time known and >= 2",
|
|
||||||
in_b.ndims
|
|
||||||
);
|
|
||||||
|
|
||||||
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
|
||||||
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
|
||||||
@ -47,13 +39,13 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
|
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
|
||||||
|
|
||||||
// Deduce ndims of the result of matmul.
|
// Deduce ndims of the result of matmul.
|
||||||
let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap());
|
let ndims_int = max(in_a.ndims, in_b.ndims);
|
||||||
let ndims = llvm_usize.const_int(ndims_int, false);
|
let ndims = llvm_usize.const_int(ndims_int, false);
|
||||||
|
|
||||||
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
||||||
// destination ndarray to store the result of matmul.
|
// destination ndarray to store the result of matmul.
|
||||||
let (lhs, rhs, dst) = {
|
let (lhs, rhs, dst) = {
|
||||||
let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false);
|
let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false);
|
||||||
let in_lhs_shape = TypedArrayLikeAdapter::from(
|
let in_lhs_shape = TypedArrayLikeAdapter::from(
|
||||||
ArraySliceValue::from_ptr_val(
|
ArraySliceValue::from_ptr_val(
|
||||||
in_a.shape().base_ptr(ctx, generator),
|
in_a.shape().base_ptr(ctx, generator),
|
||||||
@ -63,7 +55,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
|_, _, val| val.into_int_value(),
|
|_, _, val| val.into_int_value(),
|
||||||
|_, _, val| val.into(),
|
|_, _, val| val.into(),
|
||||||
);
|
);
|
||||||
let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false);
|
let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false);
|
||||||
let in_rhs_shape = TypedArrayLikeAdapter::from(
|
let in_rhs_shape = TypedArrayLikeAdapter::from(
|
||||||
ArraySliceValue::from_ptr_val(
|
ArraySliceValue::from_ptr_val(
|
||||||
in_b.shape().base_ptr(ctx, generator),
|
in_b.shape().base_ptr(ctx, generator),
|
||||||
@ -116,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
|
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
|
||||||
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
|
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
|
||||||
|
|
||||||
let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int))
|
let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
|
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
|
||||||
unsafe {
|
unsafe {
|
||||||
@ -266,10 +258,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
(out_dtype, out): (Type, NDArrayOut<'ctx>),
|
(out_dtype, out): (Type, NDArrayOut<'ctx>),
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Sanity check, but type inference should prevent this.
|
// Sanity check, but type inference should prevent this.
|
||||||
assert!(
|
assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input");
|
||||||
self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0),
|
|
||||||
"np.matmul disallows scalar input"
|
|
||||||
);
|
|
||||||
|
|
||||||
// If both arguments are 2-D they are multiplied like conventional matrices.
|
// If both arguments are 2-D they are multiplied like conventional matrices.
|
||||||
//
|
//
|
||||||
@ -282,14 +271,14 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
// If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its
|
// If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its
|
||||||
// dimensions. After matrix multiplication the appended 1 is removed.
|
// dimensions. After matrix multiplication the appended 1 is removed.
|
||||||
|
|
||||||
let new_a = if self.ndims.unwrap() == 1 {
|
let new_a = if self.ndims == 1 {
|
||||||
// Prepend 1 to its dimensions
|
// Prepend 1 to its dimensions
|
||||||
self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||||
} else {
|
} else {
|
||||||
*self
|
*self
|
||||||
};
|
};
|
||||||
|
|
||||||
let new_b = if other.ndims.unwrap() == 1 {
|
let new_b = if other.ndims == 1 {
|
||||||
// Append 1 to its dimensions
|
// Append 1 to its dimensions
|
||||||
other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||||
} else {
|
} else {
|
||||||
@ -305,12 +294,12 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
let mut postindices = vec![];
|
let mut postindices = vec![];
|
||||||
let zero = ctx.ctx.i32_type().const_zero();
|
let zero = ctx.ctx.i32_type().const_zero();
|
||||||
|
|
||||||
if self.ndims.unwrap() == 1 {
|
if self.ndims == 1 {
|
||||||
// Remove the prepended 1
|
// Remove the prepended 1
|
||||||
postindices.push(RustNDIndex::SingleElement(zero));
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
}
|
}
|
||||||
|
|
||||||
if other.ndims.unwrap() == 1 {
|
if other.ndims == 1 {
|
||||||
// Remove the appended 1
|
// Remove the appended 1
|
||||||
postindices.push(RustNDIndex::Ellipsis);
|
postindices.push(RustNDIndex::Ellipsis);
|
||||||
postindices.push(RustNDIndex::SingleElement(zero));
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
@ -42,7 +42,7 @@ mod view;
|
|||||||
pub struct NDArrayValue<'ctx> {
|
pub struct NDArrayValue<'ctx> {
|
||||||
value: PointerValue<'ctx>,
|
value: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
ndims: u64,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
}
|
}
|
||||||
@ -62,7 +62,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
pub fn from_pointer_value(
|
pub fn from_pointer_value(
|
||||||
ptr: PointerValue<'ctx>,
|
ptr: PointerValue<'ctx>,
|
||||||
dtype: BasicTypeEnum<'ctx>,
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
ndims: Option<u64>,
|
ndims: u64,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
name: Option<&'ctx str>,
|
name: Option<&'ctx str>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
@ -245,26 +245,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
) {
|
) {
|
||||||
if self.ndims.is_some() && src_ndarray.ndims.is_some() {
|
|
||||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||||
} else {
|
|
||||||
let self_ndims = self.load_ndims(ctx);
|
|
||||||
let src_ndims = src_ndarray.load_ndims(ctx);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
self_ndims,
|
|
||||||
src_ndims,
|
|
||||||
""
|
|
||||||
).unwrap(),
|
|
||||||
"0:AssertionError",
|
|
||||||
"NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
|
|
||||||
[Some(self_ndims), Some(src_ndims), None],
|
|
||||||
ctx.current_loc
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
|
let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
|
||||||
self.copy_shape_from_array(generator, ctx, src_shape);
|
self.copy_shape_from_array(generator, ctx, src_shape);
|
||||||
@ -296,26 +277,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
) {
|
) {
|
||||||
if self.ndims.is_some() && src_ndarray.ndims.is_some() {
|
|
||||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||||
} else {
|
|
||||||
let self_ndims = self.load_ndims(ctx);
|
|
||||||
let src_ndims = src_ndarray.load_ndims(ctx);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
self_ndims,
|
|
||||||
src_ndims,
|
|
||||||
""
|
|
||||||
).unwrap(),
|
|
||||||
"0:AssertionError",
|
|
||||||
"NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
|
|
||||||
[Some(self_ndims), Some(src_ndims), None],
|
|
||||||
ctx.current_loc
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
|
let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
|
||||||
self.copy_strides_from_array(generator, ctx, src_strides);
|
self.copy_strides_from_array(generator, ctx, src_strides);
|
||||||
@ -380,11 +342,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let clone = if self.ndims.is_some() {
|
let clone = self.get_type().construct_uninitialized(generator, ctx, None);
|
||||||
self.get_type().construct_uninitialized(generator, ctx, None)
|
|
||||||
} else {
|
|
||||||
self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||||
@ -437,11 +395,9 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> TupleValue<'ctx> {
|
) -> TupleValue<'ctx> {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
let objects = (0..self.ndims.unwrap())
|
let objects = (0..self.ndims)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let dim = unsafe {
|
let dim = unsafe {
|
||||||
self.shape().get_typed_unchecked(
|
self.shape().get_typed_unchecked(
|
||||||
@ -459,7 +415,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
TupleType::new(
|
TupleType::new(
|
||||||
generator,
|
generator,
|
||||||
ctx.ctx,
|
ctx.ctx,
|
||||||
&repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(),
|
&repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(),
|
||||||
)
|
)
|
||||||
.construct_from_objects(ctx, objects, None)
|
.construct_from_objects(ctx, objects, None)
|
||||||
}
|
}
|
||||||
@ -473,11 +429,9 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> TupleValue<'ctx> {
|
) -> TupleValue<'ctx> {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
let objects = (0..self.ndims.unwrap())
|
let objects = (0..self.ndims)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let dim = unsafe {
|
let dim = unsafe {
|
||||||
self.strides().get_typed_unchecked(
|
self.strides().get_typed_unchecked(
|
||||||
@ -495,15 +449,15 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
TupleType::new(
|
TupleType::new(
|
||||||
generator,
|
generator,
|
||||||
ctx.ctx,
|
ctx.ctx,
|
||||||
&repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(),
|
&repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(),
|
||||||
)
|
)
|
||||||
.construct_from_objects(ctx, objects, None)
|
.construct_from_objects(ctx, objects, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn is_unsized(&self) -> Option<bool> {
|
pub fn is_unsized(&self) -> bool {
|
||||||
self.ndims.map(|ndims| ndims == 0)
|
self.ndims == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the element present in this `ndarray` if this is unsized.
|
/// Returns the element present in this `ndarray` if this is unsized.
|
||||||
@ -512,11 +466,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> Option<BasicValueEnum<'ctx>> {
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
let Some(is_unsized) = self.is_unsized() else {
|
if self.is_unsized() {
|
||||||
panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
|
||||||
};
|
|
||||||
|
|
||||||
if is_unsized {
|
|
||||||
// NOTE: `np.size(self) == 0` here is never possible.
|
// NOTE: `np.size(self) == 0` here is never possible.
|
||||||
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||||
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||||
@ -534,8 +484,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> ScalarOrNDArray<'ctx> {
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
|
||||||
|
|
||||||
if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) {
|
if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) {
|
||||||
ScalarOrNDArray::Scalar(unsized_elem)
|
ScalarOrNDArray::Scalar(unsized_elem)
|
||||||
} else {
|
} else {
|
||||||
|
@ -26,9 +26,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndmin: u64,
|
ndmin: u64,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
|
let ndims = self.ndims;
|
||||||
|
|
||||||
let ndims = self.ndims.unwrap();
|
|
||||||
|
|
||||||
if ndims < ndmin {
|
if ndims < ndmin {
|
||||||
// Extend the dimensions with np.newaxis.
|
// Extend the dimensions with np.newaxis.
|
||||||
@ -67,13 +65,13 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
// not contiguous but could be reshaped without copying data. Look into how numpy does
|
// not contiguous but could be reshaped without copying data. Look into how numpy does
|
||||||
// it.
|
// it.
|
||||||
|
|
||||||
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims))
|
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims)
|
||||||
.construct_uninitialized(generator, ctx, None);
|
.construct_uninitialized(generator, ctx, None);
|
||||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
|
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
|
||||||
|
|
||||||
// Resolve negative indices
|
// Resolve negative indices
|
||||||
let size = self.size(generator, ctx);
|
let size = self.size(generator, ctx);
|
||||||
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false);
|
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false);
|
||||||
let dst_shape = dst_ndarray.shape();
|
let dst_shape = dst_ndarray.shape();
|
||||||
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
|
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
|
||||||
generator,
|
generator,
|
||||||
@ -121,7 +119,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
axes: Option<PointerValue<'ctx>>,
|
axes: Option<PointerValue<'ctx>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
|
|
||||||
assert!(
|
assert!(
|
||||||
axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into())
|
axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into())
|
||||||
);
|
);
|
||||||
@ -130,7 +127,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None);
|
let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None);
|
||||||
|
|
||||||
let axes = if let Some(axes) = axes {
|
let axes = if let Some(axes) = axes {
|
||||||
let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false);
|
let num_axes = self.llvm_usize.const_int(self.ndims, false);
|
||||||
|
|
||||||
// `axes = nullptr` if `axes` is unspecified.
|
// `axes = nullptr` if `axes` is unspecified.
|
||||||
let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None);
|
let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None);
|
||||||
|
Loading…
Reference in New Issue
Block a user