forked from M-Labs/nac3
core: Implement `ndarray.copy`
This commit is contained in:
parent
a94927a11d
commit
c3b122acfc
|
@ -342,6 +342,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
.nth(1)
|
||||
.map(|(var_id, ty)| (*ty, *var_id))
|
||||
.unwrap();
|
||||
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||
|
||||
let top_level_def_list = vec![
|
||||
|
@ -518,13 +519,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty],
|
||||
fields: Vec::default(),
|
||||
methods: vec![
|
||||
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
|
||||
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
|
||||
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
|
||||
],
|
||||
ancestors: Vec::default(),
|
||||
constructor: None,
|
||||
resolver: None,
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.copy".into(),
|
||||
simple_name: "copy".into(),
|
||||
signature: ndarray_copy_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_copy(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.fill".into(),
|
||||
simple_name: "fill".into(),
|
||||
|
|
|
@ -203,6 +203,15 @@ impl TopLevelComposer {
|
|||
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: ndarray_copy_fun_ret_ty.0,
|
||||
vars: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
}));
|
||||
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
|
@ -220,6 +229,7 @@ impl TopLevelComposer {
|
|||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
fields: Mapping::from([
|
||||
("copy".into(), (ndarray_copy_fun_ty, true)),
|
||||
("fill".into(), (ndarray_fill_fun_ty, true)),
|
||||
]),
|
||||
params: VarMap::from([
|
||||
|
@ -228,6 +238,8 @@ impl TopLevelComposer {
|
|||
]),
|
||||
});
|
||||
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
|
||||
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
int64,
|
||||
|
|
|
@ -654,6 +654,56 @@ fn call_ndarray_eye_impl<'ctx>(
|
|||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
fn ndarray_copy_impl<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
this: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let ndarray = create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&this,
|
||||
|_, ctx, shape| {
|
||||
Ok(shape.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, shape, idx| {
|
||||
Ok(shape.get_dims().get(ctx, generator, idx, None))
|
||||
},
|
||||
)?;
|
||||
|
||||
let len = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let len_bytes = ctx.builder
|
||||
.build_int_mul(
|
||||
len,
|
||||
sizeof_ty.size_of().unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.get_data().get_ptr(ctx),
|
||||
this.get_data().get_ptr(ctx),
|
||||
len_bytes,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -826,6 +876,36 @@ pub fn gen_ndarray_identity<'ctx>(
|
|||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.copy`.
|
||||
pub fn gen_ndarray_copy<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
_fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_some());
|
||||
assert!(args.is_empty());
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
|
||||
let this_arg = obj
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(context, generator, this_ty)?;
|
||||
|
||||
ndarray_copy_impl(
|
||||
generator,
|
||||
context,
|
||||
this_elem_ty,
|
||||
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.fill`.
|
||||
pub fn gen_ndarray_fill<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [26]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar15]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar15\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [28]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [33]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar14, typevar15]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar14\", \"typevar15\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [34]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [42]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue