forked from M-Labs/nac3
1
0
Fork 0

core: Implement `ndarray.fill`

This commit is contained in:
David Mak 2024-03-06 16:53:41 +08:00
parent 3d2abf73c8
commit 96b7f29679
4 changed files with 119 additions and 19 deletions

View File

@ -323,17 +323,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
} else {
unreachable!()
};
let (
(ndarray_dtype_ty, _),
(ndarray_ndims_ty, _),
) = if let TypeEnum::TObj { params, .. } = &*primitives.1.get_ty(primitives.0.ndarray) {
(
params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
)
} else {
let TypeEnum::TObj {
fields: ndarray_fields,
params: ndarray_params,
..
} = &*primitives.1.get_ty(primitives.0.ndarray) else {
unreachable!()
};
let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params
.iter()
.next()
.map(|(var_id, ty)| (*ty, *var_id))
.unwrap();
let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params
.iter()
.nth(1)
.map(|(var_id, ty)| (*ty, *var_id))
.unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
PRIMITIVE_DEF_IDS.int32,
@ -507,12 +517,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
object_id: PRIMITIVE_DEF_IDS.ndarray,
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
fields: Vec::default(),
methods: Vec::default(),
methods: vec![
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
],
ancestors: Vec::default(),
constructor: None,
resolver: None,
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.fill".into(),
simple_name: "fill".into(),
signature: ndarray_fill_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
Ok(None)
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "int32".into(),
simple_name: "int32".into(),

View File

@ -203,9 +203,25 @@ impl TopLevelComposer {
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "value".into(),
ty: ndarray_dtype_tvar.0,
default_value: None,
},
],
ret: none,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::new(),
fields: Mapping::from([
("fill".into(), (ndarray_fill_fun_ty, true)),
]),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),

View File

@ -375,9 +375,6 @@ fn call_ndarray_empty_impl<'ctx>(
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
///
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
/// with the given value (as opposed to all elements within the array).
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
@ -441,10 +438,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input
///
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
/// with the given value (as opposed to all elements within the array).
/// as its input.
fn ndarray_fill_indexed<'ctx, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -832,3 +826,56 @@ pub fn gen_ndarray_identity<'ctx>(
llvm_usize.const_zero(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.fill`.
pub fn gen_ndarray_fill<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj.as_ref().unwrap().1.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(())
}

View File

@ -52,6 +52,14 @@ def test_ndarray_identity():
n: ndarray[float, 2] = np_identity(2)
consume_ndarray_2(n)
def test_ndarray_fill():
n: ndarray[float, 2] = np_empty([2, 2])
n.fill(1.0)
output_float64(n[0][0])
output_float64(n[0][1])
output_float64(n[1][0])
output_float64(n[1][1])
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -60,5 +68,6 @@ def run() -> int32:
test_ndarray_full()
test_ndarray_eye()
test_ndarray_identity()
test_ndarray_fill()
return 0