forked from M-Labs/nac3
1
0
Fork 0

core: Implement `ndarray.copy`

This commit is contained in:
David Mak 2024-03-07 13:02:13 +08:00
parent a94927a11d
commit c3b122acfc
8 changed files with 118 additions and 8 deletions

View File

@ -342,6 +342,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.nth(1)
.map(|(var_id, ty)| (*ty, *var_id))
.unwrap();
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let top_level_def_list = vec![
@ -518,13 +519,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
fields: Vec::default(),
methods: vec![
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
],
ancestors: Vec::default(),
constructor: None,
resolver: None,
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.copy".into(),
simple_name: "copy".into(),
signature: ndarray_copy_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_copy(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.fill".into(),
simple_name: "fill".into(),

View File

@ -203,6 +203,15 @@ impl TopLevelComposer {
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: ndarray_copy_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
@ -220,6 +229,7 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)),
]),
params: VarMap::from([
@ -228,6 +238,8 @@ impl TopLevelComposer {
]),
});
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
let primitives = PrimitiveStore {
int32,
int64,

View File

@ -654,6 +654,56 @@ fn call_ndarray_eye_impl<'ctx>(
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn ndarray_copy_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
Ok(shape.get_dims().get(ctx, generator, idx, None))
},
)?;
let len = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.get_dims().get_ptr(ctx),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder
.build_int_mul(
len,
sizeof_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ndarray.get_data().get_ptr(ctx),
this.get_data().get_ptr(ctx),
len_bytes,
llvm_i1.const_zero(),
);
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
@ -826,6 +876,36 @@ pub fn gen_ndarray_identity<'ctx>(
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.copy`.
pub fn gen_ndarray_copy<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_some());
assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.fill`.
pub fn gen_ndarray_fill<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [26]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar15]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar15\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [28]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [33]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar14, typevar15]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar14\", \"typevar15\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [34]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [42]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n",
]