forked from M-Labs/nac3
core: Implement `ndarray.fill`
This commit is contained in:
parent
3d2abf73c8
commit
96b7f29679
|
@ -323,17 +323,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
let (
|
|
||||||
(ndarray_dtype_ty, _),
|
let TypeEnum::TObj {
|
||||||
(ndarray_ndims_ty, _),
|
fields: ndarray_fields,
|
||||||
) = if let TypeEnum::TObj { params, .. } = &*primitives.1.get_ty(primitives.0.ndarray) {
|
params: ndarray_params,
|
||||||
(
|
..
|
||||||
params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
|
} = &*primitives.1.get_ty(primitives.0.ndarray) else {
|
||||||
params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params
|
||||||
|
.iter()
|
||||||
|
.next()
|
||||||
|
.map(|(var_id, ty)| (*ty, *var_id))
|
||||||
|
.unwrap();
|
||||||
|
let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params
|
||||||
|
.iter()
|
||||||
|
.nth(1)
|
||||||
|
.map(|(var_id, ty)| (*ty, *var_id))
|
||||||
|
.unwrap();
|
||||||
|
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||||
|
|
||||||
let top_level_def_list = vec![
|
let top_level_def_list = vec![
|
||||||
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
|
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
|
||||||
PRIMITIVE_DEF_IDS.int32,
|
PRIMITIVE_DEF_IDS.int32,
|
||||||
|
@ -507,12 +517,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
object_id: PRIMITIVE_DEF_IDS.ndarray,
|
object_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||||
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
|
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: vec![
|
||||||
|
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
|
||||||
|
],
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
constructor: None,
|
constructor: None,
|
||||||
resolver: None,
|
resolver: None,
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
|
name: "ndarray.fill".into(),
|
||||||
|
simple_name: "fill".into(),
|
||||||
|
signature: ndarray_fill_ty.0,
|
||||||
|
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||||
|
instance_to_symbol: HashMap::default(),
|
||||||
|
instance_to_stmt: HashMap::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|
|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
)))),
|
||||||
|
loc: None,
|
||||||
|
})),
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "int32".into(),
|
name: "int32".into(),
|
||||||
simple_name: "int32".into(),
|
simple_name: "int32".into(),
|
||||||
|
|
|
@ -203,9 +203,25 @@ impl TopLevelComposer {
|
||||||
|
|
||||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||||
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||||
|
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "value".into(),
|
||||||
|
ty: ndarray_dtype_tvar.0,
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret: none,
|
||||||
|
vars: VarMap::from([
|
||||||
|
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||||
|
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||||
|
]),
|
||||||
|
}));
|
||||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||||
fields: Mapping::new(),
|
fields: Mapping::from([
|
||||||
|
("fill".into(), (ndarray_fill_fun_ty, true)),
|
||||||
|
]),
|
||||||
params: VarMap::from([
|
params: VarMap::from([
|
||||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||||
|
|
|
@ -375,9 +375,6 @@ fn call_ndarray_empty_impl<'ctx>(
|
||||||
|
|
||||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||||
/// its input.
|
/// its input.
|
||||||
///
|
|
||||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
|
||||||
/// with the given value (as opposed to all elements within the array).
|
|
||||||
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
@ -441,10 +438,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
||||||
/// as its input
|
/// as its input.
|
||||||
///
|
|
||||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
|
||||||
/// with the given value (as opposed to all elements within the array).
|
|
||||||
fn ndarray_fill_indexed<'ctx, ValueFn>(
|
fn ndarray_fill_indexed<'ctx, ValueFn>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -831,4 +825,57 @@ pub fn gen_ndarray_identity<'ctx>(
|
||||||
n_arg.into_int_value(),
|
n_arg.into_int_value(),
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
).map(NDArrayValue::into)
|
).map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.fill`.
|
||||||
|
pub fn gen_ndarray_fill<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
assert!(obj.is_some());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
|
let this_arg = obj.as_ref().unwrap().1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, this_ty)?
|
||||||
|
.into_pointer_value();
|
||||||
|
let value_ty = fun.0.args[0].ty;
|
||||||
|
let value_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, value_ty)?;
|
||||||
|
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|
||||||
|
|generator, ctx, _| {
|
||||||
|
let value = if value_arg.is_pointer_value() {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
|
||||||
|
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
|
||||||
|
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
copy,
|
||||||
|
value_arg.into_pointer_value(),
|
||||||
|
value_arg.get_type().size_of().map(Into::into).unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
copy.into()
|
||||||
|
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
||||||
|
value_arg
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
|
@ -52,6 +52,14 @@ def test_ndarray_identity():
|
||||||
n: ndarray[float, 2] = np_identity(2)
|
n: ndarray[float, 2] = np_identity(2)
|
||||||
consume_ndarray_2(n)
|
consume_ndarray_2(n)
|
||||||
|
|
||||||
|
def test_ndarray_fill():
|
||||||
|
n: ndarray[float, 2] = np_empty([2, 2])
|
||||||
|
n.fill(1.0)
|
||||||
|
output_float64(n[0][0])
|
||||||
|
output_float64(n[0][1])
|
||||||
|
output_float64(n[1][0])
|
||||||
|
output_float64(n[1][1])
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
@ -60,5 +68,6 @@ def run() -> int32:
|
||||||
test_ndarray_full()
|
test_ndarray_full()
|
||||||
test_ndarray_eye()
|
test_ndarray_eye()
|
||||||
test_ndarray_identity()
|
test_ndarray_identity()
|
||||||
|
test_ndarray_fill()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue