forked from M-Labs/nac3
WIP: core/ndstrides: checkpoint
This commit is contained in:
parent
0df2f26c98
commit
f8b934096d
@ -1738,40 +1738,42 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
todo!()
|
||||
|
||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
||||
// let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
// let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||
// passing it to the elementwise codegen function
|
||||
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
if op == ast::Unaryop::Invert {
|
||||
ast::Unaryop::Not
|
||||
} else {
|
||||
unreachable!(
|
||||
"ufunc {} not supported for ndarray[bool, N]",
|
||||
op.op_info().method_name,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
op
|
||||
};
|
||||
// let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
None,
|
||||
val,
|
||||
|generator, ctx, val| {
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
},
|
||||
)?;
|
||||
// // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||
// // passing it to the elementwise codegen function
|
||||
// let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
// if op == ast::Unaryop::Invert {
|
||||
// ast::Unaryop::Not
|
||||
// } else {
|
||||
// unreachable!(
|
||||
// "ufunc {} not supported for ndarray[bool, N]",
|
||||
// op.op_info().method_name,
|
||||
// )
|
||||
// }
|
||||
// } else {
|
||||
// op
|
||||
// };
|
||||
|
||||
res.as_base_value().into()
|
||||
// let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||
// generator,
|
||||
// ctx,
|
||||
// ndarray_dtype,
|
||||
// None,
|
||||
// val,
|
||||
// |generator, ctx, val| {
|
||||
// gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
||||
// .unwrap()
|
||||
// .to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
// },
|
||||
// )?;
|
||||
|
||||
// res.as_base_value().into()
|
||||
} else {
|
||||
unimplemented!()
|
||||
}))
|
||||
|
@ -131,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ctx,
|
||||
|_generator, _ctx| Ok(copy.value),
|
||||
|generator, ctx| {
|
||||
let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy
|
||||
let ndarray = ndarray.make_copy(generator, ctx, "np_array_copied_ndarray"); // Force copy
|
||||
Ok(Some(ndarray.instance.value))
|
||||
},
|
||||
|_generator, _ctx| {
|
||||
|
@ -99,7 +99,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
|
||||
|
||||
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
|
||||
ndarray.fill(generator, ctx, fill_value.value);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
ndarray
|
||||
}
|
||||
|
||||
@ -177,9 +177,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
num_rows: Int<'ctx, SizeT>,
|
||||
num_cols: Int<'ctx, SizeT>,
|
||||
diagonal: Int<'ctx, SizeT>,
|
||||
nrows: Int<'ctx, SizeT>,
|
||||
ncols: Int<'ctx, SizeT>,
|
||||
offset: Int<'ctx, SizeT>,
|
||||
) -> Self {
|
||||
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||
@ -188,7 +188,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
&[num_rows, num_cols],
|
||||
&[nrows, ncols],
|
||||
"eye_ndarray",
|
||||
);
|
||||
|
||||
@ -209,7 +209,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
|
||||
// Write to element
|
||||
let be_one =
|
||||
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
|
||||
row_i.add(ctx, offset, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
|
||||
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
|
||||
|
||||
let p = nditer.get_pointer(generator, ctx);
|
||||
@ -220,4 +220,16 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
.unwrap();
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Create an ndarray like `np.identity`.
|
||||
pub fn from_np_identity<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
size: Int<'ctx, SizeT>,
|
||||
) -> Self {
|
||||
// Convenient implementation
|
||||
let offset = IntModel(SizeT).const_0(generator, ctx.ctx);
|
||||
NDArrayObject::from_np_eye(generator, ctx, dtype, size, size, offset)
|
||||
}
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// TODO: Reimplement this? This method does give us the contiguous `data`, but
|
||||
// this creates a few extra bytes of useless information because an entire NDArray
|
||||
// is allocated. Though this is super convenient.
|
||||
let data = self.make_clone(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
|
||||
let data = self.make_copy(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
|
||||
let data = data.pointer_cast(generator, ctx, item_model, "");
|
||||
result.set(ctx, |f| f.data, data);
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
@ -408,11 +408,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
|
||||
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
|
||||
///
|
||||
/// The new ndarray will own its data and will be C-contiguous.
|
||||
#[must_use]
|
||||
pub fn make_clone<G: CodeGenerator + ?Sized>(
|
||||
pub fn make_copy<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -597,18 +597,21 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the ndarray with a value.
|
||||
/// Fill the ndarray with a scalar.
|
||||
///
|
||||
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
||||
pub fn fill<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
) {
|
||||
// Sanity check on scalar's type.
|
||||
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
|
||||
|
||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
let p = nditer.get_pointer(generator, ctx);
|
||||
ctx.builder.build_store(p, fill_value).unwrap();
|
||||
ctx.builder.build_store(p, scalar.value).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
@ -1029,9 +1029,21 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_copy(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
|ctx, obj, _fun, _args, generator| {
|
||||
// Parse `self`
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let this_arg = obj
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, this_ty)?;
|
||||
|
||||
// Implementation
|
||||
let this = AnyObject { value: this_arg, ty: this_ty };
|
||||
let this = NDArrayObject::from_object(generator, ctx, this);
|
||||
let copy = this.make_copy(generator, ctx, "np_copy");
|
||||
Ok(Some(copy.instance.value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
@ -1047,7 +1059,27 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
|
||||
// Parse `self`
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let this_arg = obj
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, this_ty)?;
|
||||
|
||||
// Parse `value`
|
||||
let value_ty = fun.0.args[0].ty;
|
||||
let value_arg =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
|
||||
|
||||
// Implementation
|
||||
let this = AnyObject { value: this_arg, ty: this_ty };
|
||||
let this = NDArrayObject::from_object(generator, ctx, this);
|
||||
|
||||
let value = ScalarObject { value: value_arg, dtype: value_ty };
|
||||
this.fill(generator, ctx, value);
|
||||
|
||||
Ok(None)
|
||||
},
|
||||
)))),
|
||||
@ -1373,11 +1405,13 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1..=3));
|
||||
|
||||
// Parse argument `object`
|
||||
let object_ty = fun.0.args[0].ty;
|
||||
let object =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||
let object = AnyObject { ty: object_ty, value: object };
|
||||
|
||||
// Parse argument `copy`
|
||||
let copy_arg = if let Some(arg) = args
|
||||
.iter()
|
||||
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||
@ -1395,6 +1429,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
|
||||
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
||||
|
||||
// Implementation
|
||||
let (dtype, ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
@ -1431,14 +1466,14 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
// Parse argument #1 shape
|
||||
// Parse argument `shape`
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
// Parse argument #2 fill_value
|
||||
// Parse argument `fill_value`
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
let fill_value =
|
||||
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
||||
@ -1490,9 +1525,53 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_eye(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
|ctx, _obj, fun, args, generator| {
|
||||
// Parse argument `N`
|
||||
let nrows_ty = fun.0.args[0].ty;
|
||||
let nrows_arg =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)?;
|
||||
|
||||
// Parse argument `M`
|
||||
let ncols_ty = fun.0.args[1].ty;
|
||||
let ncols_arg = if let Some(arg) = args
|
||||
.iter()
|
||||
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||
{
|
||||
arg.1.clone().to_basic_value_enum(ctx, generator, ncols_ty)
|
||||
} else {
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, nrows_ty)
|
||||
}?;
|
||||
|
||||
// Parse argument `k`
|
||||
let offset_ty = fun.0.args[2].ty;
|
||||
let offset_arg = if let Some(arg) = args
|
||||
.iter()
|
||||
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
||||
{
|
||||
arg.1.clone().to_basic_value_enum(ctx, generator, offset_ty)
|
||||
} else {
|
||||
Ok(ctx.gen_symbol_val(
|
||||
generator,
|
||||
fun.0.args[2].default_value.as_ref().unwrap(),
|
||||
offset_ty,
|
||||
))
|
||||
}?;
|
||||
|
||||
// Implementation
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let nrows =
|
||||
sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap();
|
||||
let ncols =
|
||||
sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap();
|
||||
let offset =
|
||||
sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap();
|
||||
|
||||
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndarray = NDArrayObject::from_np_eye(
|
||||
generator, ctx, dtype, nrows, ncols, offset,
|
||||
);
|
||||
|
||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
@ -1504,9 +1583,16 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
&[(int32, "n")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
Box::new(|ctx, _obj, fun, args, generator| {
|
||||
// Parse argument `n`
|
||||
let n_ty = fun.0.args[0].ty;
|
||||
let n_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
||||
let n = IntModel(SizeT).check_value(generator, ctx.ctx, n_arg).unwrap();
|
||||
|
||||
// Implementation
|
||||
let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndarray = NDArrayObject::from_np_identity(generator, ctx, dtype, n);
|
||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
PrimDef::FunNpArange => {
|
||||
@ -1841,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { value: arg, ty: arg_ty };
|
||||
|
||||
Ok(Some(arg.len(generator, ctx).value))
|
||||
Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
Loading…
Reference in New Issue
Block a user