1
0
forked from M-Labs/nac3

WIP: core/ndstrides: checkpoint

This commit is contained in:
lyken 2024-08-15 11:41:33 +08:00
parent 0df2f26c98
commit f8b934096d
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 159 additions and 56 deletions

View File

@ -1738,40 +1738,42 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_ => val.into(),
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
todo!()
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
if op == ast::Unaryop::Invert {
ast::Unaryop::Not
} else {
unreachable!(
"ufunc {} not supported for ndarray[bool, N]",
op.op_info().method_name,
)
}
} else {
op
};
// let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
let res = numpy::ndarray_elementwise_unaryop_impl(
generator,
ctx,
ndarray_dtype,
None,
val,
|generator, ctx, val| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
.unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;
// // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// // passing it to the elementwise codegen function
// let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
// if op == ast::Unaryop::Invert {
// ast::Unaryop::Not
// } else {
// unreachable!(
// "ufunc {} not supported for ndarray[bool, N]",
// op.op_info().method_name,
// )
// }
// } else {
// op
// };
res.as_base_value().into()
// let res = numpy::ndarray_elementwise_unaryop_impl(
// generator,
// ctx,
// ndarray_dtype,
// None,
// val,
// |generator, ctx, val| {
// gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
// .unwrap()
// .to_basic_value_enum(ctx, generator, ndarray_dtype)
// },
// )?;
// res.as_base_value().into()
} else {
unimplemented!()
}))

View File

@ -131,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy
let ndarray = ndarray.make_copy(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {

View File

@ -99,7 +99,7 @@ impl<'ctx> NDArrayObject<'ctx> {
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value.value);
ndarray.fill(generator, ctx, fill_value);
ndarray
}
@ -177,9 +177,9 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
num_rows: Int<'ctx, SizeT>,
num_cols: Int<'ctx, SizeT>,
diagonal: Int<'ctx, SizeT>,
nrows: Int<'ctx, SizeT>,
ncols: Int<'ctx, SizeT>,
offset: Int<'ctx, SizeT>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
@ -188,7 +188,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator,
ctx,
dtype,
&[num_rows, num_cols],
&[nrows, ncols],
"eye_ndarray",
);
@ -209,7 +209,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Write to element
let be_one =
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
row_i.add(ctx, offset, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
@ -220,4 +220,16 @@ impl<'ctx> NDArrayObject<'ctx> {
.unwrap();
todo!()
}
/// Create an ndarray like `np.identity`.
pub fn from_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Int<'ctx, SizeT>,
) -> Self {
// Convenient implementation
let offset = IntModel(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::from_np_eye(generator, ctx, dtype, size, size, offset)
}
}

View File

@ -137,7 +137,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient.
let data = self.make_clone(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = self.make_copy(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -408,11 +408,11 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray
}
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
///
/// The new ndarray will own its data and will be C-contiguous.
#[must_use]
pub fn make_clone<G: CodeGenerator + ?Sized>(
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -597,18 +597,21 @@ impl<'ctx> NDArrayObject<'ctx> {
}
}
/// Fill the ndarray with a value.
/// Fill the ndarray with a scalar.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
fill_value: BasicValueEnum<'ctx>,
scalar: ScalarObject<'ctx>,
) {
// Sanity check on scalar's type.
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, fill_value).unwrap();
ctx.builder.build_store(p, scalar.value).unwrap();
Ok(())
})
.unwrap();

View File

@ -1029,9 +1029,21 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_copy(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
|ctx, obj, _fun, _args, generator| {
// Parse `self`
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(ctx, generator, this_ty)?;
// Implementation
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this);
let copy = this.make_copy(generator, ctx, "np_copy");
Ok(Some(copy.instance.value.as_basic_value_enum()))
},
)))),
loc: None,
@ -1047,7 +1059,27 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
// Parse `self`
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(ctx, generator, this_ty)?;
// Parse `value`
let value_ty = fun.0.args[0].ty;
let value_arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
// Implementation
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this);
let value = ScalarObject { value: value_arg, dtype: value_ty };
this.fill(generator, ctx, value);
Ok(None)
},
)))),
@ -1373,11 +1405,13 @@ impl<'a> BuiltinBuilder<'a> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
// Parse argument `object`
let object_ty = fun.0.args[0].ty;
let object =
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
let object = AnyObject { ty: object_ty, value: object };
// Parse argument `copy`
let copy_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
@ -1395,6 +1429,7 @@ impl<'a> BuiltinBuilder<'a> {
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
// We simply make the output ndarray's ndims correct with `atleast_nd`.
// Implementation
let (dtype, ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let output_ndims = extract_ndims(&ctx.unifier, ndims);
@ -1431,14 +1466,14 @@ impl<'a> BuiltinBuilder<'a> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
// Parse argument `shape`
let shape_ty = fun.0.args[0].ty;
let shape =
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
// Parse argument #2 fill_value
// Parse argument `fill_value`
let fill_value_ty = fun.0.args[1].ty;
let fill_value =
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
@ -1490,9 +1525,53 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_eye(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
|ctx, _obj, fun, args, generator| {
// Parse argument `N`
let nrows_ty = fun.0.args[0].ty;
let nrows_arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)?;
// Parse argument `M`
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
arg.1.clone().to_basic_value_enum(ctx, generator, ncols_ty)
} else {
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)
}?;
// Parse argument `k`
let offset_ty = fun.0.args[2].ty;
let offset_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
arg.1.clone().to_basic_value_enum(ctx, generator, offset_ty)
} else {
Ok(ctx.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty,
))
}?;
// Implementation
let sizet_model = IntModel(SizeT);
let nrows =
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
let ncols =
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
let offset =
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndarray = NDArrayObject::from_np_eye(
generator, ctx, dtype, nrows, ncols, offset,
);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
},
)))),
loc: None,
@ -1504,9 +1583,16 @@ impl<'a> BuiltinBuilder<'a> {
prim.name(),
self.ndarray_float_2d,
&[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
Box::new(|ctx, _obj, fun, args, generator| {
// Parse argument `n`
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let n = IntModel(SizeT).check_value(generator, ctx.ctx, n_arg).unwrap();
// Implementation
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndarray = NDArrayObject::from_np_identity(generator, ctx, dtype, n);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}),
),
PrimDef::FunNpArange => {
@ -1841,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { value: arg, ty: arg_ty };
Ok(Some(arg.len(generator, ctx).value))
Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
},
)))),
loc: None,