forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: fix nditer

This commit is contained in:
lyken 2024-08-15 15:13:02 +08:00
parent a69a441bdd
commit 1c48d54afa
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
4 changed files with 34 additions and 22 deletions

View File

@ -86,7 +86,7 @@ struct NDIter {
indices[axis] = 0;
// TODO: Can be optimized with backstrides.
element -= strides[axis] * shape[axis];
element -= strides[axis] * (shape[axis] - 1);
} else {
element += strides[axis];
break;

View File

@ -32,7 +32,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into()
} else {
unreachable!()
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
@ -40,28 +40,28 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into()
} else {
unreachable!()
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
@ -192,6 +192,8 @@ impl<'ctx> NDArrayObject<'ctx> {
"eye_ndarray",
);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
@ -218,7 +220,8 @@ impl<'ctx> NDArrayObject<'ctx> {
Ok(())
})
.unwrap();
todo!()
ndarray
}
/// Create an ndarray like `np.identity`.

View File

@ -649,9 +649,12 @@ impl<'ctx> NDArrayObject<'ctx> {
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
// Reolsve negative indices
let size = self.size(generator, ctx);
let new_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, new_ndims, new_shape);
let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
let dst_shape =
dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape");
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();

View File

@ -1556,15 +1556,21 @@ impl<'a> BuiltinBuilder<'a> {
}?;
// Implementation
let sizet_model = IntModel(SizeT);
let nrows =
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
let ncols =
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
let offset =
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
let i32_model = IntModel(Int32);
let nrows = i32_model
.check_value(generator, ctx.ctx, nrows_arg)
.unwrap()
.s_extend_or_bit_cast(generator, ctx, SizeT, "nrows");
let ncols = i32_model
.check_value(generator, ctx.ctx, ncols_arg)
.unwrap()
.s_extend_or_bit_cast(generator, ctx, SizeT, "ncols");
let offset = i32_model
.check_value(generator, ctx.ctx, offset_arg)
.unwrap()
.s_extend_or_bit_cast(generator, ctx, SizeT, "offset");
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndarray = NDArrayObject::from_np_eye(
generator, ctx, dtype, nrows, ncols, offset,
);