WIP
This commit is contained in:
parent
aa673fce4e
commit
d2ce0679ed
|
@ -1130,9 +1130,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
Some("f_pow_i")
|
Some("f_pow_i")
|
||||||
);
|
);
|
||||||
Ok(Some(res.into()))
|
Ok(Some(res.into()))
|
||||||
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
} else if matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) && matches!(&*ctx.unifier.get_ty(ty2), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||||
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||||
|
|
||||||
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
let left_val = NDArrayValue::from_ptr_val(
|
let left_val = NDArrayValue::from_ptr_val(
|
||||||
left_val.into_pointer_value(),
|
left_val.into_pointer_value(),
|
||||||
|
@ -1147,7 +1150,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
ndarray_dtype1,
|
||||||
if is_aug_assign { Some(left_val) } else { None },
|
if is_aug_assign { Some(left_val) } else { None },
|
||||||
left_val,
|
left_val,
|
||||||
right_val,
|
right_val,
|
||||||
|
|
|
@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
Some("N".into()),
|
Some("N".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
let size_t = primitives.0.usize();
|
||||||
|
|
||||||
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
||||||
let exception_fields = vec![
|
let exception_fields = vec![
|
||||||
("__name__".into(), int32, true),
|
("__name__".into(), int32, true),
|
||||||
|
@ -345,6 +347,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
.nth(1)
|
.nth(1)
|
||||||
.map(|(var_id, ty)| (*ty, *var_id))
|
.map(|(var_id, ty)| (*ty, *var_id))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
|
||||||
|
size_t,
|
||||||
|
Some("ndarray_ndims".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
||||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||||
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
|
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
|
||||||
|
@ -699,7 +706,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
name: "ndarray.__iadd__".into(),
|
name: "ndarray.__iadd__".into(),
|
||||||
simple_name: "__iadd__".into(),
|
simple_name: "__iadd__".into(),
|
||||||
signature: ndarray_iadd_ty.0,
|
signature: ndarray_iadd_ty.0,
|
||||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
|
|
@ -285,8 +285,11 @@ impl TopLevelComposer {
|
||||||
]),
|
]),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||||
|
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
|
||||||
|
|
||||||
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
|
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
|
||||||
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray).unwrap();
|
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
|
||||||
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
|
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
|
||||||
|
|
||||||
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
|
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
|
||||||
|
|
|
@ -309,6 +309,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
ndarray: ndarray_t,
|
ndarray: ndarray_t,
|
||||||
..
|
..
|
||||||
} = *store;
|
} = *store;
|
||||||
|
let size_t = store.usize();
|
||||||
|
|
||||||
/* int ======== */
|
/* int ======== */
|
||||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||||
|
@ -345,9 +346,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
|
|
||||||
/* ndarray ===== */
|
/* ndarray ===== */
|
||||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
||||||
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||||
impl_pow(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
|
||||||
|
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||||
|
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t);
|
impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t);
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,6 +81,20 @@ def test_ndarray_add():
|
||||||
output_float64(y[1][0])
|
output_float64(y[1][0])
|
||||||
output_float64(y[1][1])
|
output_float64(y[1][1])
|
||||||
|
|
||||||
|
# def test_ndarray_add_broadcast():
|
||||||
|
# x = np_identity(2)
|
||||||
|
# y: ndarray[float, 2] = x + np_ones([2])
|
||||||
|
#
|
||||||
|
# output_float64(x[0][0])
|
||||||
|
# output_float64(x[0][1])
|
||||||
|
# output_float64(x[1][0])
|
||||||
|
# output_float64(x[1][1])
|
||||||
|
#
|
||||||
|
# output_float64(y[0][0])
|
||||||
|
# output_float64(y[0][1])
|
||||||
|
# output_float64(y[1][0])
|
||||||
|
# output_float64(y[1][1])
|
||||||
|
|
||||||
def test_ndarray_iadd():
|
def test_ndarray_iadd():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
x += np_ones([2, 2])
|
x += np_ones([2, 2])
|
||||||
|
|
Loading…
Reference in New Issue