forked from M-Labs/nac3
core: Implement `ndarray.fill`
This commit is contained in:
parent
3d2abf73c8
commit
96b7f29679
|
@ -323,17 +323,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let (
|
||||
(ndarray_dtype_ty, _),
|
||||
(ndarray_ndims_ty, _),
|
||||
) = if let TypeEnum::TObj { params, .. } = &*primitives.1.get_ty(primitives.0.ndarray) {
|
||||
(
|
||||
params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
|
||||
params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
|
||||
)
|
||||
} else {
|
||||
|
||||
let TypeEnum::TObj {
|
||||
fields: ndarray_fields,
|
||||
params: ndarray_params,
|
||||
..
|
||||
} = &*primitives.1.get_ty(primitives.0.ndarray) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params
|
||||
.iter()
|
||||
.next()
|
||||
.map(|(var_id, ty)| (*ty, *var_id))
|
||||
.unwrap();
|
||||
let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params
|
||||
.iter()
|
||||
.nth(1)
|
||||
.map(|(var_id, ty)| (*ty, *var_id))
|
||||
.unwrap();
|
||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||
|
||||
let top_level_def_list = vec![
|
||||
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
|
||||
PRIMITIVE_DEF_IDS.int32,
|
||||
|
@ -507,12 +517,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
object_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
|
||||
fields: Vec::default(),
|
||||
methods: Vec::default(),
|
||||
methods: vec![
|
||||
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
|
||||
],
|
||||
ancestors: Vec::default(),
|
||||
constructor: None,
|
||||
resolver: None,
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.fill".into(),
|
||||
simple_name: "fill".into(),
|
||||
signature: ndarray_fill_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
|
||||
Ok(None)
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "int32".into(),
|
||||
simple_name: "int32".into(),
|
||||
|
|
|
@ -203,9 +203,25 @@ impl TopLevelComposer {
|
|||
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "value".into(),
|
||||
ty: ndarray_dtype_tvar.0,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret: none,
|
||||
vars: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
}));
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
fields: Mapping::new(),
|
||||
fields: Mapping::from([
|
||||
("fill".into(), (ndarray_fill_fun_ty, true)),
|
||||
]),
|
||||
params: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
|
|
|
@ -375,9 +375,6 @@ fn call_ndarray_empty_impl<'ctx>(
|
|||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||
/// its input.
|
||||
///
|
||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||
/// with the given value (as opposed to all elements within the array).
|
||||
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
|
@ -441,10 +438,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
|||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
||||
/// as its input
|
||||
///
|
||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||
/// with the given value (as opposed to all elements within the array).
|
||||
/// as its input.
|
||||
fn ndarray_fill_indexed<'ctx, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -832,3 +826,56 @@ pub fn gen_ndarray_identity<'ctx>(
|
|||
llvm_usize.const_zero(),
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.fill`.
|
||||
pub fn gen_ndarray_fill<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<(), String> {
|
||||
assert!(obj.is_some());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let this_arg = obj.as_ref().unwrap().1.clone()
|
||||
.to_basic_value_enum(context, generator, this_ty)?
|
||||
.into_pointer_value();
|
||||
let value_ty = fun.0.args[0].ty;
|
||||
let value_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, value_ty)?;
|
||||
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
context,
|
||||
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|
||||
|generator, ctx, _| {
|
||||
let value = if value_arg.is_pointer_value() {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
copy,
|
||||
value_arg.into_pointer_value(),
|
||||
value_arg.get_type().size_of().map(Into::into).unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
copy.into()
|
||||
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
||||
value_arg
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -52,6 +52,14 @@ def test_ndarray_identity():
|
|||
n: ndarray[float, 2] = np_identity(2)
|
||||
consume_ndarray_2(n)
|
||||
|
||||
def test_ndarray_fill():
|
||||
n: ndarray[float, 2] = np_empty([2, 2])
|
||||
n.fill(1.0)
|
||||
output_float64(n[0][0])
|
||||
output_float64(n[0][1])
|
||||
output_float64(n[1][0])
|
||||
output_float64(n[1][1])
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -60,5 +68,6 @@ def run() -> int32:
|
|||
test_ndarray_full()
|
||||
test_ndarray_eye()
|
||||
test_ndarray_identity()
|
||||
test_ndarray_fill()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue