Implement ndarray.array #412

Merged
sb10q merged 7 commits from enhance/issue-149-ndarray/numpy-array-func into master 2024-06-11 19:30:35 +08:00
13 changed files with 643 additions and 18 deletions

View File

@ -541,18 +541,20 @@ impl<'ctx> ListType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}

View File

@ -1,12 +1,16 @@
use inkwell::{IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use nac3parser::ast::{Operator, StrRef};
use crate::{
codegen::{
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ListType,
ListValue,
NDArrayType,
NDArrayValue,
ProxyType,
ProxyValue,
TypedArrayLikeAccessor,
TypedArrayLikeAdapter,
@ -31,9 +35,10 @@ use crate::{
symbol_resolver::ValueEnum,
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::{FunSignature, Type},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
/// Creates an uninitialized `NDArray` instance.
@ -589,6 +594,405 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ty: PointerType<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let list_ty = ListType::from_type(ty, llvm_usize);
let list_elem_ty = list_ty.element_type();
let ndims = llvm_usize.const_int(1, false);
match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty))
}
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
todo!("Getting ndims for list[ndarray] not supported")
}
_ => ndims,
}
}
/// Returns the number of dimensions for an array-like object as an [`IntValue`].
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
match value {
BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => {
NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx)
}
BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => {
llvm_ndlist_get_ndims(generator, ctx, v.get_type())
}
_ => llvm_usize.const_zero(),
}
}
/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`].
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
src_lst: ListValue<'ctx>,
dim: u64,
) -> Result<(), String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let list_elem_ty = src_lst.get_type().element_type();
match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
// The stride of elements in this dimension, i.e. the number of elements between arr[i]
// and arr[i + 1] in this dimension
let stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
gen_for_range_callback(
generator,
ctx,
true,
|_, _| Ok(llvm_usize.const_zero()),
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, i| {
let offset = ctx.builder.build_int_mul(
stride,
i,
"",
).unwrap();
let dst_ptr = unsafe {
ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap()
};
let nested_lst_elem = ListValue::from_ptr_val(
unsafe {
src_lst.data().get_unchecked(ctx, generator, &i, None)
}.into_pointer_value(),
llvm_usize,
None,
);
ndarray_from_ndlist_impl(
generator,
ctx,
elem_ty,
(dst_arr, dst_ptr),
nested_lst_elem,
dim + 1,
)?;
Ok(())
},
)?;
}
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
todo!("Not implemented for list[ndarray]")
}
_ => {
let lst_len = src_lst.load_size(ctx, None);
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let cpy_len = ctx.builder.build_int_mul(
ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(),
sizeof_elem,
""
).unwrap();
call_memcpy_generic(
ctx,
dst_slice_ptr,
src_lst.data().base_ptr(ctx, generator),
cpy_len,
llvm_i1.const_zero(),
);
}
}
Ok(())
}
/// LLVM-typed implementation for `ndarray.array`.
fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
object: BasicValueEnum<'ctx>,
copy: IntValue<'ctx>,
ndmin: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndmin = ctx.builder
.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "")
.unwrap();
// TODO(Derppening): Add assertions for sizes of different dimensions
// object is not a pointer - 0-dim NDArray
if !object.is_pointer_value() {
let ndarray = create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[],
)?;
unsafe {
ndarray.data()
.set_unchecked(ctx, generator, &llvm_usize.const_zero(), object);
}
return Ok(ndarray)
}
let object = object.into_pointer_value();
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_instance(object, llvm_usize).is_ok() {
let object = NDArrayValue::from_ptr_val(object, llvm_usize, None);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
let copy_nez = ctx.builder
.build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "")
.unwrap();
let ndmin_gt_ndims = ctx.builder
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
.unwrap();
Ok(ctx.builder
.build_and(copy_nez, ndmin_gt_ndims, "")
.unwrap())
},
|generator, ctx| {
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&object,
|_, ctx, object| {
let ndims = object.load_ndims(ctx);
let ndmin_gt_ndims = ctx.builder
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
.unwrap();
Ok(ctx.builder
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
.map(BasicValueEnum::into_int_value)
.unwrap())
},
|generator, ctx, object, idx| {
let ndims = object.load_ndims(ctx);
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
// The number of dimensions to prepend 1's to
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder
.build_int_compare(IntPredicate::UGE, idx, offset, "")
.unwrap())
},
|_, _| {
Ok(Some(llvm_usize.const_int(1, false)))
},
|_, ctx| {
Ok(Some(ctx.builder.build_int_sub(
idx,
offset,
""
).unwrap()))
},
)?.map(BasicValueEnum::into_int_value).unwrap())
},
)?;
ndarray_sliced_copyto_impl(
generator,
ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
(object, object.data().base_ptr(ctx, generator)),
0,
&[],
)?;
Ok(Some(ndarray.as_base_value()))
},
|_, _| {
Ok(Some(object.as_base_value()))
},
)?;
return Ok(NDArrayValue::from_ptr_val(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_usize,
None,
))
}
// Remaining case: TList
assert!(ListValue::is_instance(object, llvm_usize).is_ok());
let object = ListValue::from_ptr_val(object, llvm_usize, None);
// The number of dimensions to prepend 1's to
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&object,
|generator, ctx, object| {
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
let ndmin_gt_ndims = ctx.builder
.build_int_compare(IntPredicate::UGT, ndmin, ndims, "")
.unwrap();
Ok(ctx.builder
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
.map(BasicValueEnum::into_int_value)
.unwrap())
},
|generator, ctx, object, idx| {
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder
.build_int_compare(IntPredicate::ULT, idx, offset, "")
.unwrap())
},
|_, _| {
Ok(Some(llvm_usize.const_int(1, false)))
},
|generator, ctx| {
let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| {
ctx.ctx.struct_type(
&[
elem_ty.ptr_type(AddressSpace::default()).into(),
llvm_usize.into(),
],
false,
)
};
let llvm_i8 = ctx.ctx.i8_type();
let llvm_list_i8 = make_llvm_list(llvm_i8.into());
let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default());
// Cast list to { i8*, usize } since we only care about the size
let lst = generator.gen_var_alloc(
ctx,
ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(),
None,
).unwrap();
ctx.builder.build_store(
lst,
ctx.builder.build_bitcast(
object.as_base_value(),
llvm_plist_i8,
"",
).unwrap(),
).unwrap();
let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap();
gen_for_range_callback(
generator,
ctx,
true,
|_, _| Ok(llvm_usize.const_zero()),
(|_, _| Ok(stop), false),
|_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _| {
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
.ptr_type(AddressSpace::default());
let this_dim = ctx.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
.map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap())
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let this_dim = ListValue::from_ptr_val(
this_dim,
llvm_usize,
None,
);
// TODO: Assert this_dim.sz != 0
let next_dim = unsafe {
this_dim.data()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}.into_pointer_value();
ctx.builder.build_store(
lst,
ctx.builder.build_bitcast(
next_dim,
llvm_plist_i8,
"",
).unwrap(),
).unwrap();
Ok(())
},
)?;
let lst = ListValue::from_ptr_val(
ctx.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap(),
llvm_usize,
None,
);
Ok(Some(lst.load_size(ctx, None)))
},
)?.map(BasicValueEnum::into_int_value).unwrap())
},
)?;
ndarray_from_ndlist_impl(
generator,
ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
object,
0,
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -1450,6 +1854,69 @@ pub fn gen_ndarray_full<'ctx>(
).map(NDArrayValue::into)
}
pub fn gen_ndarray_array<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
}
TypeEnum::TList { ty } => {
let mut ty = *ty;
while let TypeEnum::TList { ty: elem_ty } = &*context.unifier.get_ty_immutable(ty) {
ty = *elem_ty;
}
ty
},
_ => obj_ty,
};
let obj_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, obj_ty)?;
let copy_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) {
let copy_ty = fun.0.args[1].ty;
arg.1.clone().to_basic_value_enum(context, generator, copy_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[1].default_value.as_ref().unwrap(),
fun.0.args[1].ty,
)
};
let ndmin_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) {
let ndmin_ty = fun.0.args[2].ty;
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
fun.0.args[2].ty,
)
};
call_ndarray_array_impl(
generator,
context,
obj_elem_ty,
obj_arg,
copy_arg.into_int_value(),
ndmin_arg.into_int_value(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,

View File

@ -872,12 +872,14 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
ctx.builder.position_at_end(then_bb);
let then_val = then_fn(generator, ctx)?;
let then_end_bb = ctx.builder.get_insert_block().unwrap();
if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(end_bb).unwrap();
}
ctx.builder.position_at_end(else_bb);
let else_val = else_fn(generator, ctx)?;
let else_end_bb = ctx.builder.get_insert_block().unwrap();
if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(end_bb).unwrap();
}
@ -889,7 +891,7 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
assert_eq!(tv_ty, ev.as_basic_value_enum().get_type());
let phi = ctx.builder.build_phi(tv_ty, "").unwrap();
phi.add_incoming(&[(&tv, then_bb), (&ev, else_bb)]);
phi.add_incoming(&[(&tv, then_end_bb), (&ev, else_end_bb)]);
Some(phi.as_basic_value())
},

View File

@ -788,6 +788,42 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
}),
)
},
{
let tv = unifier.get_fresh_var(Some("T".into()), None);
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_array".into(),
simple_name: "np_array".into(),
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "object".into(), ty: tv.0, default_value: None },
FuncArg {
name: "copy".into(),
ty: boolean,
default_value: Some(SymbolValue::Bool(true)),
},
FuncArg {
name: "ndmin".into(),
ty: int32,
default_value: Some(SymbolValue::U32(0)),
},
],
ret: ndarray,
vars: VarMap::from([(tv.1, tv.0)]),
})),
var_id: vec![tv.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_array(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
},
)))),
loc: None,
}))
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_eye".into(),
simple_name: "np_eye".into(),

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
@ -691,3 +692,35 @@ pub fn parse_parameter_default_value(
]))
}
}
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray =>
unpack_ndarray_var_tys(unifier, ty).0,
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
_ => ty
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
};
if values.len() > 1 {
todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented")
}
u64::try_from(values[0].clone()).unwrap()
}
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
_ => 0
}
}

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [239]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [241]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [246]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [247]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [255]\n}\n",
]

View File

@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
@ -1254,6 +1254,77 @@ impl<'a> Inferencer<'a> {
}))
}
// 1-argument ndarray n-dimensional creation functions
if id == &"np_array".into() && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?;
let keywords = keywords.iter()
.map(|v| fold::fold_keyword(self, v.clone()))
.collect::<Result<Vec<_>, _>>()?;
let ndmin_kw = keywords.iter()
.find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
let ndims = if let Some(ndmin_kw) = ndmin_kw {
match &ndmin_kw.node.value.node {
ExprKind::Constant { value, .. } => match value {
ast::Constant::Int(value) => *value as u64,
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
}
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
}
} else {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
};
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
);
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "object".into(),
ty: arg0.custom.unwrap(),
default_value: None
},
FuncArg {
name: "copy".into(),
ty: self.primitives.bool,
default_value: Some(SymbolValue::Bool(true)),
},
FuncArg {
name: "ndmin".into(),
ty: self.primitives.int32,
default_value: Some(SymbolValue::U32(0)),
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords,
},
}))
}
Ok(None)
}
@ -1264,11 +1335,10 @@ impl<'a> Inferencer<'a> {
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
return Ok(spec_call_func)
} else {
func
};
}
let func = Box::new(self.fold_expr(func)?);
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
let keywords = keywords

View File

@ -170,6 +170,7 @@ def patch(module):
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
module.np_array = np.array
# NumPy Math functions
module.np_isnan = np.isnan

View File

@ -97,6 +97,19 @@ def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2)
output_ndarray_float_2(n)
def test_ndarray_array():
n1: ndarray[float, 1] = np_array([1.0, 2.0, 3.0])
output_ndarray_float_1(n1)
n1to2: ndarray[float, 2] = np_array([1.0, 2.0, 3.0], ndmin=2)
output_ndarray_float_2(n1to2)
n2: ndarray[float, 2] = np_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
output_ndarray_float_2(n2)
# Copy
n2_cpy: ndarray[float, 2] = np_array(n2, copy=False)
n2_cpy.fill(0.0)
output_ndarray_float_2(n2_cpy)
def test_ndarray_identity():
n: ndarray[float, 2] = np_identity(2)
output_ndarray_float_2(n)
@ -1373,6 +1386,7 @@ def run() -> int32:
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_array()
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()