From 96b7f29679b5c7e485bc40359668fab701c6faea Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 6 Mar 2024 16:53:41 +0800 Subject: [PATCH] core: Implement `ndarray.fill` --- nac3core/src/toplevel/builtins.rs | 48 ++++++++++++++++++----- nac3core/src/toplevel/helper.rs | 18 ++++++++- nac3core/src/toplevel/numpy.rs | 63 ++++++++++++++++++++++++++---- nac3standalone/demo/src/ndarray.py | 9 +++++ 4 files changed, 119 insertions(+), 19 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 03cd922d..0f7ebd02 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -323,17 +323,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { } else { unreachable!() }; - let ( - (ndarray_dtype_ty, _), - (ndarray_ndims_ty, _), - ) = if let TypeEnum::TObj { params, .. } = &*primitives.1.get_ty(primitives.0.ndarray) { - ( - params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(), - params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(), - ) - } else { + + let TypeEnum::TObj { + fields: ndarray_fields, + params: ndarray_params, + .. + } = &*primitives.1.get_ty(primitives.0.ndarray) else { unreachable!() }; + + let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params + .iter() + .next() + .map(|(var_id, ty)| (*ty, *var_id)) + .unwrap(); + let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params + .iter() + .nth(1) + .map(|(var_id, ty)| (*ty, *var_id)) + .unwrap(); + let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( PRIMITIVE_DEF_IDS.int32, @@ -507,12 +517,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { object_id: PRIMITIVE_DEF_IDS.ndarray, type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty], fields: Vec::default(), - methods: Vec::default(), + methods: vec![ + ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), + ], ancestors: Vec::default(), constructor: None, resolver: None, loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.fill".into(), + simple_name: "fill".into(), + signature: ndarray_fill_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; + Ok(None) + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 679b6c1c..dcd3385f 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -203,9 +203,25 @@ impl TopLevelComposer { let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "value".into(), + ty: ndarray_dtype_tvar.0, + default_value: None, + }, + ], + ret: none, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, - fields: Mapping::new(), + fields: Mapping::from([ + ("fill".into(), (ndarray_fill_fun_ty, true)), + ]), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 26c8044d..7521291b 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -375,9 +375,6 @@ fn call_ndarray_empty_impl<'ctx>( /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. -/// -/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements -/// with the given value (as opposed to all elements within the array). fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -441,10 +438,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( } /// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input -/// -/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements -/// with the given value (as opposed to all elements within the array). +/// as its input. fn ndarray_fill_indexed<'ctx, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, @@ -831,4 +825,57 @@ pub fn gen_ndarray_identity<'ctx>( n_arg.into_int_value(), llvm_usize.const_zero(), ).map(NDArrayValue::into) -} \ No newline at end of file +} + +/// Generates LLVM IR for `ndarray.fill`. +pub fn gen_ndarray_fill<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result<(), String> { + assert!(obj.is_some()); + assert_eq!(args.len(), 1); + + let llvm_usize = generator.get_size_type(context.ctx); + + let this_ty = obj.as_ref().unwrap().0; + let this_arg = obj.as_ref().unwrap().1.clone() + .to_basic_value_enum(context, generator, this_ty)? + .into_pointer_value(); + let value_ty = fun.0.args[0].ty; + let value_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, value_ty)?; + + ndarray_fill_flattened( + generator, + context, + NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), + |generator, ctx, _| { + let value = if value_arg.is_pointer_value() { + let llvm_i1 = ctx.ctx.bool_type(); + + let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; + + call_memcpy_generic( + ctx, + copy, + value_arg.into_pointer_value(), + value_arg.get_type().size_of().map(Into::into).unwrap(), + llvm_i1.const_zero(), + ); + + copy.into() + } else if value_arg.is_int_value() || value_arg.is_float_value() { + value_arg + } else { + unreachable!() + }; + + Ok(value) + } + )?; + + Ok(()) +} diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 38abd9f9..16eec742 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -52,6 +52,14 @@ def test_ndarray_identity(): n: ndarray[float, 2] = np_identity(2) consume_ndarray_2(n) +def test_ndarray_fill(): + n: ndarray[float, 2] = np_empty([2, 2]) + n.fill(1.0) + output_float64(n[0][0]) + output_float64(n[0][1]) + output_float64(n[1][0]) + output_float64(n[1][1]) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -60,5 +68,6 @@ def run() -> int32: test_ndarray_full() test_ndarray_eye() test_ndarray_identity() + test_ndarray_fill() return 0