forked from M-Labs/nac3
WIP: core/ndstrides: fix nditer
This commit is contained in:
parent
a69a441bdd
commit
1c48d54afa
|
@ -86,7 +86,7 @@ struct NDIter {
|
|||
indices[axis] = 0;
|
||||
|
||||
// TODO: Can be optimized with backstrides.
|
||||
element -= strides[axis] * shape[axis];
|
||||
element -= strides[axis] * (shape[axis] - 1);
|
||||
} else {
|
||||
element += strides[axis];
|
||||
break;
|
||||
|
|
|
@ -32,7 +32,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,28 +40,28 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,6 +192,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
"eye_ndarray",
|
||||
);
|
||||
|
||||
// Create data and make the matrix like look np.eye()
|
||||
ndarray.create_data(generator, ctx);
|
||||
ndarray
|
||||
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||
|
@ -218,7 +220,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
todo!()
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.identity`.
|
||||
|
|
|
@ -649,9 +649,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
|
||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
||||
|
||||
// Reolsve negative indices
|
||||
let size = self.size(generator, ctx);
|
||||
let new_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
|
||||
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, new_ndims, new_shape);
|
||||
let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
|
||||
let dst_shape =
|
||||
dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape");
|
||||
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape);
|
||||
|
||||
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||
|
|
|
@ -1556,15 +1556,21 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}?;
|
||||
|
||||
// Implementation
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let nrows =
|
||||
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
|
||||
let ncols =
|
||||
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
|
||||
let offset =
|
||||
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
|
||||
let i32_model = IntModel(Int32);
|
||||
let nrows = i32_model
|
||||
.check_value(generator, ctx.ctx, nrows_arg)
|
||||
.unwrap()
|
||||
.s_extend_or_bit_cast(generator, ctx, SizeT, "nrows");
|
||||
let ncols = i32_model
|
||||
.check_value(generator, ctx.ctx, ncols_arg)
|
||||
.unwrap()
|
||||
.s_extend_or_bit_cast(generator, ctx, SizeT, "ncols");
|
||||
let offset = i32_model
|
||||
.check_value(generator, ctx.ctx, offset_arg)
|
||||
.unwrap()
|
||||
.s_extend_or_bit_cast(generator, ctx, SizeT, "offset");
|
||||
|
||||
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndarray = NDArrayObject::from_np_eye(
|
||||
generator, ctx, dtype, nrows, ncols, offset,
|
||||
);
|
||||
|
|
Loading…
Reference in New Issue