forked from M-Labs/nac3
WIP: core/ndstrides: fix nditer
This commit is contained in:
parent
a69a441bdd
commit
1c48d54afa
|
@ -86,7 +86,7 @@ struct NDIter {
|
||||||
indices[axis] = 0;
|
indices[axis] = 0;
|
||||||
|
|
||||||
// TODO: Can be optimized with backstrides.
|
// TODO: Can be optimized with backstrides.
|
||||||
element -= strides[axis] * shape[axis];
|
element -= strides[axis] * (shape[axis] - 1);
|
||||||
} else {
|
} else {
|
||||||
element += strides[axis];
|
element += strides[axis];
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -32,7 +32,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
ctx.gen_string(generator, "").value.into()
|
ctx.gen_string(generator, "").value.into()
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,28 +40,28 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
dtype: Type,
|
||||||
) -> BasicValueEnum<'ctx> {
|
) -> BasicValueEnum<'ctx> {
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||||
.iter()
|
.iter()
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
{
|
{
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||||
.iter()
|
.iter()
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
{
|
{
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||||
ctx.ctx.f64_type().const_float(1.0).into()
|
ctx.ctx.f64_type().const_float(1.0).into()
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||||
ctx.ctx.bool_type().const_int(1, false).into()
|
ctx.ctx.bool_type().const_int(1, false).into()
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
ctx.gen_string(generator, "1").value.into()
|
ctx.gen_string(generator, "1").value.into()
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,6 +192,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
"eye_ndarray",
|
"eye_ndarray",
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Create data and make the matrix like look np.eye()
|
||||||
|
ndarray.create_data(generator, ctx);
|
||||||
ndarray
|
ndarray
|
||||||
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||||
|
@ -218,7 +220,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
todo!()
|
|
||||||
|
ndarray
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an ndarray like `np.identity`.
|
/// Create an ndarray like `np.identity`.
|
||||||
|
|
|
@ -649,9 +649,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
|
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
|
||||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
||||||
|
|
||||||
|
// Reolsve negative indices
|
||||||
let size = self.size(generator, ctx);
|
let size = self.size(generator, ctx);
|
||||||
let new_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
|
let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
|
||||||
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, new_ndims, new_shape);
|
let dst_shape =
|
||||||
|
dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape");
|
||||||
|
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape);
|
||||||
|
|
||||||
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
||||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||||
|
|
|
@ -1556,15 +1556,21 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let sizet_model = IntModel(SizeT);
|
let i32_model = IntModel(Int32);
|
||||||
let nrows =
|
let nrows = i32_model
|
||||||
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
|
.check_value(generator, ctx.ctx, nrows_arg)
|
||||||
let ncols =
|
.unwrap()
|
||||||
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
|
.s_extend_or_bit_cast(generator, ctx, SizeT, "nrows");
|
||||||
let offset =
|
let ncols = i32_model
|
||||||
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
|
.check_value(generator, ctx.ctx, ncols_arg)
|
||||||
|
.unwrap()
|
||||||
|
.s_extend_or_bit_cast(generator, ctx, SizeT, "ncols");
|
||||||
|
let offset = i32_model
|
||||||
|
.check_value(generator, ctx.ctx, offset_arg)
|
||||||
|
.unwrap()
|
||||||
|
.s_extend_or_bit_cast(generator, ctx, SizeT, "offset");
|
||||||
|
|
||||||
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
let ndarray = NDArrayObject::from_np_eye(
|
let ndarray = NDArrayObject::from_np_eye(
|
||||||
generator, ctx, dtype, nrows, ncols, offset,
|
generator, ctx, dtype, nrows, ncols, offset,
|
||||||
);
|
);
|
||||||
|
|
Loading…
Reference in New Issue