forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: checkpoint

This commit is contained in:
lyken 2024-08-15 11:41:33 +08:00
parent 0df2f26c98
commit f8b934096d
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 159 additions and 56 deletions

View File

@ -1738,40 +1738,42 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_ => val.into(), _ => val.into(),
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); todo!()
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); // let llvm_usize = generator.get_size_type(ctx.ctx);
// let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
// passing it to the elementwise codegen function
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
if op == ast::Unaryop::Invert {
ast::Unaryop::Not
} else {
unreachable!(
"ufunc {} not supported for ndarray[bool, N]",
op.op_info().method_name,
)
}
} else {
op
};
let res = numpy::ndarray_elementwise_unaryop_impl( // // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
generator, // // passing it to the elementwise codegen function
ctx, // let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
ndarray_dtype, // if op == ast::Unaryop::Invert {
None, // ast::Unaryop::Not
val, // } else {
|generator, ctx, val| { // unreachable!(
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? // "ufunc {} not supported for ndarray[bool, N]",
.unwrap() // op.op_info().method_name,
.to_basic_value_enum(ctx, generator, ndarray_dtype) // )
}, // }
)?; // } else {
// op
// };
res.as_base_value().into() // let res = numpy::ndarray_elementwise_unaryop_impl(
// generator,
// ctx,
// ndarray_dtype,
// None,
// val,
// |generator, ctx, val| {
// gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
// .unwrap()
// .to_basic_value_enum(ctx, generator, ndarray_dtype)
// },
// )?;
// res.as_base_value().into()
} else { } else {
unimplemented!() unimplemented!()
})) }))

View File

@ -131,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx, ctx,
|_generator, _ctx| Ok(copy.value), |_generator, _ctx| Ok(copy.value),
|generator, ctx| { |generator, ctx| {
let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy let ndarray = ndarray.make_copy(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.instance.value)) Ok(Some(ndarray.instance.value))
}, },
|_generator, _ctx| { |_generator, _ctx| {

View File

@ -99,7 +99,7 @@ impl<'ctx> NDArrayObject<'ctx> {
assert!(ctx.unifier.unioned(dtype, fill_value.dtype)); assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape); let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value.value); ndarray.fill(generator, ctx, fill_value);
ndarray ndarray
} }
@ -177,9 +177,9 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type, dtype: Type,
num_rows: Int<'ctx, SizeT>, nrows: Int<'ctx, SizeT>,
num_cols: Int<'ctx, SizeT>, ncols: Int<'ctx, SizeT>,
diagonal: Int<'ctx, SizeT>, offset: Int<'ctx, SizeT>,
) -> Self { ) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype); let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype); let ndone = ndarray_one_value(generator, ctx, dtype);
@ -188,7 +188,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator, generator,
ctx, ctx,
dtype, dtype,
&[num_rows, num_cols], &[nrows, ncols],
"eye_ndarray", "eye_ndarray",
); );
@ -209,7 +209,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Write to element // Write to element
let be_one = let be_one =
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one"); row_i.add(ctx, offset, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap(); let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx); let p = nditer.get_pointer(generator, ctx);
@ -220,4 +220,16 @@ impl<'ctx> NDArrayObject<'ctx> {
.unwrap(); .unwrap();
todo!() todo!()
} }
/// Create an ndarray like `np.identity`.
pub fn from_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Int<'ctx, SizeT>,
) -> Self {
// Convenient implementation
let offset = IntModel(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::from_np_eye(generator, ctx, dtype, size, size, offset)
}
} }

View File

@ -137,7 +137,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// TODO: Reimplement this? This method does give us the contiguous `data`, but // TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray // this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient. // is allocated. Though this is super convenient.
let data = self.make_clone(generator, ctx, "").instance.get(generator, ctx, |f| f.data, ""); let data = self.make_copy(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, ""); let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data); result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap(); ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -408,11 +408,11 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray ndarray
} }
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over. /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
/// ///
/// The new ndarray will own its data and will be C-contiguous. /// The new ndarray will own its data and will be C-contiguous.
#[must_use] #[must_use]
pub fn make_clone<G: CodeGenerator + ?Sized>( pub fn make_copy<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -597,18 +597,21 @@ impl<'ctx> NDArrayObject<'ctx> {
} }
} }
/// Fill the ndarray with a value. /// Fill the ndarray with a scalar.
/// ///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>( pub fn fill<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
fill_value: BasicValueEnum<'ctx>, scalar: ScalarObject<'ctx>,
) { ) {
// Sanity check on scalar's type.
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx); let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, fill_value).unwrap(); ctx.builder.build_store(p, scalar.value).unwrap();
Ok(()) Ok(())
}) })
.unwrap(); .unwrap();

View File

@ -1029,9 +1029,21 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, _fun, _args, generator| {
gen_ndarray_copy(ctx, &obj, fun, &args, generator) // Parse `self`
.map(|val| Some(val.as_basic_value_enum())) let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(ctx, generator, this_ty)?;
// Implementation
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this);
let copy = this.make_copy(generator, ctx, "np_copy");
Ok(Some(copy.instance.value.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,
@ -1047,7 +1059,27 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; // Parse `self`
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(ctx, generator, this_ty)?;
// Parse `value`
let value_ty = fun.0.args[0].ty;
let value_arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
// Implementation
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this);
let value = ScalarObject { value: value_arg, dtype: value_ty };
this.fill(generator, ctx, value);
Ok(None) Ok(None)
}, },
)))), )))),
@ -1373,11 +1405,13 @@ impl<'a> BuiltinBuilder<'a> {
assert!(obj.is_none()); assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3)); assert!(matches!(args.len(), 1..=3));
// Parse argument `object`
let object_ty = fun.0.args[0].ty; let object_ty = fun.0.args[0].ty;
let object = let object =
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?; args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
let object = AnyObject { ty: object_ty, value: object }; let object = AnyObject { ty: object_ty, value: object };
// Parse argument `copy`
let copy_arg = if let Some(arg) = args let copy_arg = if let Some(arg) = args
.iter() .iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) .find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
@ -1395,6 +1429,7 @@ impl<'a> BuiltinBuilder<'a> {
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value. // The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
// We simply make the output ndarray's ndims correct with `atleast_nd`. // We simply make the output ndarray's ndims correct with `atleast_nd`.
// Implementation
let (dtype, ndims) = let (dtype, ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let output_ndims = extract_ndims(&ctx.unifier, ndims); let output_ndims = extract_ndims(&ctx.unifier, ndims);
@ -1431,14 +1466,14 @@ impl<'a> BuiltinBuilder<'a> {
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 2); assert_eq!(args.len(), 2);
// Parse argument #1 shape // Parse argument `shape`
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape = let shape =
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape }; let shape = AnyObject { ty: shape_ty, value: shape };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape); let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
// Parse argument #2 fill_value // Parse argument `fill_value`
let fill_value_ty = fun.0.args[1].ty; let fill_value_ty = fun.0.args[1].ty;
let fill_value = let fill_value =
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?; args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
@ -1490,9 +1525,53 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, _obj, fun, args, generator| {
gen_ndarray_eye(ctx, &obj, fun, &args, generator) // Parse argument `N`
.map(|val| Some(val.as_basic_value_enum())) let nrows_ty = fun.0.args[0].ty;
let nrows_arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)?;
// Parse argument `M`
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
arg.1.clone().to_basic_value_enum(ctx, generator, ncols_ty)
} else {
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)
}?;
// Parse argument `k`
let offset_ty = fun.0.args[2].ty;
let offset_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
arg.1.clone().to_basic_value_enum(ctx, generator, offset_ty)
} else {
Ok(ctx.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty,
))
}?;
// Implementation
let sizet_model = IntModel(SizeT);
let nrows =
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
let ncols =
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
let offset =
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndarray = NDArrayObject::from_np_eye(
generator, ctx, dtype, nrows, ncols, offset,
);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,
@ -1504,9 +1583,16 @@ impl<'a> BuiltinBuilder<'a> {
prim.name(), prim.name(),
self.ndarray_float_2d, self.ndarray_float_2d,
&[(int32, "n")], &[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, _obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator) // Parse argument `n`
.map(|val| Some(val.as_basic_value_enum())) let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let n = IntModel(SizeT).check_value(generator, ctx.ctx, n_arg).unwrap();
// Implementation
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndarray = NDArrayObject::from_np_identity(generator, ctx, dtype, n);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}), }),
), ),
PrimDef::FunNpArange => { PrimDef::FunNpArange => {
@ -1841,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { value: arg, ty: arg_ty }; let arg = AnyObject { value: arg, ty: arg_ty };
Ok(Some(arg.len(generator, ctx).value)) Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,