forked from M-Labs/nac3
core: support tuple and int32 input for np_empty, np_ones, and more
This commit is contained in:
parent
b21df53e0d
commit
5b11a1dbdd
|
@ -163,10 +163,11 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
for shape_dim in shape {
|
for &shape_dim in shape {
|
||||||
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
let shape_dim_gez = ctx
|
let shape_dim_gez = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
|
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
|
@ -189,7 +190,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
for (i, shape_dim) in shape.iter().enumerate() {
|
for (i, &shape_dim) in shape.iter().enumerate() {
|
||||||
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
let ndarray_dim = unsafe {
|
let ndarray_dim = unsafe {
|
||||||
ndarray.dim_sizes().ptr_offset_unchecked(
|
ndarray.dim_sizes().ptr_offset_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -199,7 +201,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
|
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||||
|
@ -286,22 +288,68 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
|
///
|
||||||
|
/// ### Notes on `shape`
|
||||||
|
///
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
||||||
|
/// learn how `shape` gets from being a Python user expression to here.
|
||||||
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match shape {
|
||||||
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||||
|
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||||
create_ndarray_dyn_shape(
|
create_ndarray_dyn_shape(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
elem_ty,
|
elem_ty,
|
||||||
&shape,
|
&shape_list,
|
||||||
|_, ctx, shape| Ok(shape.load_size(ctx, None)),
|
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|
||||||
|generator, ctx, shape, idx| {
|
|generator, ctx, shape_list, idx| {
|
||||||
Ok(shape.data().get(ctx, generator, &idx, None).into_int_value())
|
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
BasicValueEnum::StructValue(shape_tuple) => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||||
|
|
||||||
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
|
||||||
|
let mut shape = Vec::with_capacity(ndims as usize);
|
||||||
|
for dim_i in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
|
||||||
|
shape.push(dim);
|
||||||
|
}
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
|
}
|
||||||
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||||
|
@ -486,7 +534,7 @@ fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let supported_types = [
|
let supported_types = [
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
|
@ -517,7 +565,7 @@ fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let supported_types = [
|
let supported_types = [
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
|
@ -548,7 +596,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||||
|
@ -1674,16 +1722,10 @@ pub fn gen_ndarray_empty<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_empty_impl(
|
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
|
||||||
generator,
|
|
||||||
context,
|
|
||||||
context.primitives.float,
|
|
||||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
.map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1698,16 +1740,10 @@ pub fn gen_ndarray_zeros<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_zeros_impl(
|
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
|
||||||
generator,
|
|
||||||
context,
|
|
||||||
context.primitives.float,
|
|
||||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
.map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1722,16 +1758,10 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_ones_impl(
|
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
|
||||||
generator,
|
|
||||||
context,
|
|
||||||
context.primitives.float,
|
|
||||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
.map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1746,20 +1776,13 @@ pub fn gen_ndarray_full<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 2);
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
let fill_value_ty = fun.0.args[1].ty;
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
let fill_value_arg =
|
let fill_value_arg =
|
||||||
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
call_ndarray_full_impl(
|
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
|
||||||
generator,
|
|
||||||
context,
|
|
||||||
fill_value_ty,
|
|
||||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
|
||||||
fill_value_arg,
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
.map(NDArrayValue::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -324,6 +324,9 @@ struct BuiltinBuilder<'a> {
|
||||||
|
|
||||||
num_or_ndarray_ty: TypeVar,
|
num_or_ndarray_ty: TypeVar,
|
||||||
num_or_ndarray_var_map: VarMap,
|
num_or_ndarray_var_map: VarMap,
|
||||||
|
|
||||||
|
/// See [`BuiltinBuilder::build_ndarray_from_shape_factory_function`]
|
||||||
|
ndarray_factory_fn_shape_arg_tvar: TypeVar,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> BuiltinBuilder<'a> {
|
impl<'a> BuiltinBuilder<'a> {
|
||||||
|
@ -394,6 +397,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
|
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
|
||||||
|
|
||||||
|
let ndarray_factory_fn_shape_arg_tvar = unifier.get_fresh_var(Some("Shape".into()), None);
|
||||||
|
|
||||||
BuiltinBuilder {
|
BuiltinBuilder {
|
||||||
unifier,
|
unifier,
|
||||||
primitives,
|
primitives,
|
||||||
|
@ -421,6 +426,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
num_or_ndarray_ty,
|
num_or_ndarray_ty,
|
||||||
num_or_ndarray_var_map,
|
num_or_ndarray_var_map,
|
||||||
|
|
||||||
|
ndarray_factory_fn_shape_arg_tvar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -959,21 +966,46 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build ndarray factory functions that only take in an argument `shape` of type `list[int32]` and return an ndarray.
|
/// Build ndarray factory functions that only take in an argument `shape`.
|
||||||
|
///
|
||||||
|
/// `shape` can be a tuple of int32s, a list of int32s, or a scalar int32.
|
||||||
fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(
|
debug_assert_prim_is_allowed(
|
||||||
prim,
|
prim,
|
||||||
&[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes],
|
&[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
|
//
|
||||||
|
// Ideally, we should have created a [`TypeVar`] to define all possible input
|
||||||
|
// types for the parameter "shape" like so:
|
||||||
|
// ```rust
|
||||||
|
// self.unifier.get_fresh_var_with_range(
|
||||||
|
// &[int32, list_int32, /* and more... */],
|
||||||
|
// Some("T".into()), None)
|
||||||
|
// )
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// However, there is (currently) no way to type a tuple of arbitrary length in `nac3core`.
|
||||||
|
//
|
||||||
|
// And this is the best we could do:
|
||||||
|
// ```rust
|
||||||
|
// &[ int32, list_int32, tuple_1_int32, tuple_2_int32, tuple_3_int32, ... ],
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// But this is not ideal.
|
||||||
|
//
|
||||||
|
// Instead, we delegate the responsibility of typechecking
|
||||||
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
|
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
&VarMap::new(),
|
&VarMap::new(),
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.ndarray_float,
|
self.ndarray_float,
|
||||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
// type variable
|
|
||||||
&[(self.list_int32, "shape")],
|
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(240)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -814,6 +814,150 @@ impl<'a> Inferencer<'a> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fold an ndarray `shape` argument. This function aims to fold `shape` arguments like that of
|
||||||
|
/// <https://numpy.org/doc/stable/reference/generated/numpy.zeros.html> (for `np_zeros`).
|
||||||
|
///
|
||||||
|
/// Arguments:
|
||||||
|
/// * `id` - The name of the function of the function call this `shape` argument is in. Used for error reporting.
|
||||||
|
/// * `arg_index` - The position (0-based) of this argument in the function call. Used for error reporting.
|
||||||
|
/// * `shape_expr` - [`Located<ExprKind>`] of the input argument.
|
||||||
|
///
|
||||||
|
/// On success, it returns a tuple of
|
||||||
|
/// 1) the `ndims` value inferred from the input `shape`,
|
||||||
|
/// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return.
|
||||||
|
fn fold_numpy_function_call_shape_argument(
|
||||||
|
&mut self,
|
||||||
|
id: StrRef,
|
||||||
|
arg_index: usize,
|
||||||
|
shape_expr: Located<ExprKind>,
|
||||||
|
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
|
||||||
|
/*
|
||||||
|
### Further explanation
|
||||||
|
|
||||||
|
As said, this function aims to fold `shape` arguments, but this is *not* trivial.
|
||||||
|
The root of the issue is that `nac3core` has to deduce the `ndims`
|
||||||
|
of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
|
||||||
|
|
||||||
|
There are three types of valid input to `shape`:
|
||||||
|
1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
|
||||||
|
2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
|
||||||
|
3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
|
||||||
|
|
||||||
|
For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
|
||||||
|
- For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
|
||||||
|
- For 3. `ndims` is simply 1.
|
||||||
|
|
||||||
|
For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
|
||||||
|
is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
|
||||||
|
itself to extract the input list length statically.
|
||||||
|
|
||||||
|
This implies that the user could only write:
|
||||||
|
|
||||||
|
```python
|
||||||
|
my_rgba_image = np_zeros([600, 800, 4])
|
||||||
|
# the shape argument is directly written as a list literal.
|
||||||
|
# and `nac3core` could therefore tell that ndims is `3` by
|
||||||
|
# looking at the raw AST expression itself.
|
||||||
|
```
|
||||||
|
|
||||||
|
But not:
|
||||||
|
|
||||||
|
```python
|
||||||
|
my_image_dimension = [600, 800, 4]
|
||||||
|
mystery_function_that_mutates_my_list(my_image_dimension)
|
||||||
|
my_image = np_zeros(my_image_dimension)
|
||||||
|
# what is the length now? what is `ndims`?
|
||||||
|
|
||||||
|
# it is *basically impossible* to generally determine the
|
||||||
|
# length of `my_image_dimension` statically for `ndims`!!
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Fold `shape`
|
||||||
|
let shape = self.fold_expr(shape_expr)?;
|
||||||
|
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
|
||||||
|
|
||||||
|
// Check `shape_ty` to see if its a list of int32s, a tuple of int32s, or just int32.
|
||||||
|
// Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`.
|
||||||
|
//
|
||||||
|
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
|
||||||
|
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
|
||||||
|
let ndims = match shape_ty_enum {
|
||||||
|
TypeEnum::TList { ty } => {
|
||||||
|
// Handle 1. A list of int32s
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
self.unifier.unify(*ty, self.primitives.int32).map_err(|err| {
|
||||||
|
HashSet::from([err
|
||||||
|
.at(Some(shape.location))
|
||||||
|
.to_display(self.unifier)
|
||||||
|
.to_string()])
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Special handling for (1. A python `List` (all `int32s`)).
|
||||||
|
// Read the doc above this function to see what is going on here.
|
||||||
|
if let ExprKind::List { elts, .. } = &shape.node {
|
||||||
|
// The user wrote a List literal as the input argument
|
||||||
|
elts.len() as u64
|
||||||
|
} else {
|
||||||
|
// This means the user is passing an expression of type `List`,
|
||||||
|
// but it is done so indirectly (like putting a variable referencing a `List`)
|
||||||
|
// rather than writing a List literal. We need to report an error.
|
||||||
|
return Err(HashSet::from([
|
||||||
|
format!(
|
||||||
|
"Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {location}. Input argument is of type list but not a list literal.",
|
||||||
|
arg_num = arg_index + 1,
|
||||||
|
location = shape.location
|
||||||
|
)
|
||||||
|
]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty: tuple_element_types } => {
|
||||||
|
// Handle 2. A tuple of int32s
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
// The expected type is just the tuple but with all its elements being int32.
|
||||||
|
let expected_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||||
|
ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(),
|
||||||
|
});
|
||||||
|
self.unifier.unify(shape_ty, expected_ty).map_err(|err| {
|
||||||
|
HashSet::from([err
|
||||||
|
.at(Some(shape.location))
|
||||||
|
.to_display(self.unifier)
|
||||||
|
.to_string()])
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// `ndims` can be deduced statically from the inferred Tuple type.
|
||||||
|
tuple_element_types.len() as u64
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { .. } => {
|
||||||
|
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
self.unify(self.primitives.int32, shape_ty, &shape.location)?;
|
||||||
|
|
||||||
|
// Deduce `ndims`
|
||||||
|
1
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// The user wrote an ill-typed `shape_expr`,
|
||||||
|
// so throw an error.
|
||||||
|
let shape_ty_str = self.unifier.stringify(shape_ty);
|
||||||
|
return report_error(
|
||||||
|
format!(
|
||||||
|
"Expected list literal, tuple, or int32 for argument {arg_num} of {id}, got {shape_expr_name} of type {shape_ty_str}",
|
||||||
|
arg_num = arg_index + 1,
|
||||||
|
shape_expr_name = shape.node.name(),
|
||||||
|
)
|
||||||
|
.as_str(),
|
||||||
|
shape.location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((ndims, shape))
|
||||||
|
}
|
||||||
|
|
||||||
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
|
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
|
||||||
/// returns [`None`].
|
/// returns [`None`].
|
||||||
fn try_fold_special_call(
|
fn try_fold_special_call(
|
||||||
|
@ -1141,25 +1285,15 @@ impl<'a> Inferencer<'a> {
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1-argument ndarray n-dimensional creation functions
|
// 1-argument ndarray n-dimensional factory functions
|
||||||
if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()]
|
if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()]
|
||||||
.contains(id)
|
.contains(id)
|
||||||
&& args.len() == 1
|
&& args.len() == 1
|
||||||
{
|
{
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let shape_expr = args.remove(0);
|
||||||
return report_error(
|
let (ndims, shape) =
|
||||||
format!(
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling the `shape`
|
||||||
"Expected List literal for first argument of {id}, got {}",
|
|
||||||
args[0].node.name()
|
|
||||||
)
|
|
||||||
.as_str(),
|
|
||||||
args[0].location,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndims = elts.len() as u64;
|
|
||||||
|
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
|
||||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
let ret = make_ndarray_ty(
|
let ret = make_ndarray_ty(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
|
@ -1170,7 +1304,7 @@ impl<'a> Inferencer<'a> {
|
||||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
name: "shape".into(),
|
name: "shape".into(),
|
||||||
ty: arg0.custom.unwrap(),
|
ty: shape.custom.unwrap(),
|
||||||
default_value: None,
|
default_value: None,
|
||||||
}],
|
}],
|
||||||
ret,
|
ret,
|
||||||
|
@ -1186,7 +1320,7 @@ impl<'a> Inferencer<'a> {
|
||||||
location: func.location,
|
location: func.location,
|
||||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
}),
|
}),
|
||||||
args: vec![arg0],
|
args: vec![shape],
|
||||||
keywords: vec![],
|
keywords: vec![],
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -71,17 +71,44 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_ndarray_ctor():
|
def test_ndarray_ctor():
|
||||||
n: ndarray[float, Literal[1]] = np_ndarray([1])
|
n: ndarray[float, Literal[1]] = np_ndarray([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
def test_ndarray_empty():
|
def test_ndarray_empty():
|
||||||
n: ndarray[float, 1] = np_empty([1])
|
n1: ndarray[float, 1] = np_empty([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n1)
|
||||||
|
|
||||||
|
n2: ndarray[float, 1] = np_empty(10)
|
||||||
|
consume_ndarray_1(n2)
|
||||||
|
|
||||||
|
n3: ndarray[float, 1] = np_empty((2,))
|
||||||
|
consume_ndarray_1(n3)
|
||||||
|
|
||||||
|
n4: ndarray[float, 2] = np_empty((4, 4))
|
||||||
|
consume_ndarray_2(n4)
|
||||||
|
|
||||||
|
dim4 = (5, 2)
|
||||||
|
n5: ndarray[float, 2] = np_empty(dim4)
|
||||||
|
consume_ndarray_2(n5)
|
||||||
|
|
||||||
def test_ndarray_zeros():
|
def test_ndarray_zeros():
|
||||||
n: ndarray[float, 1] = np_zeros([1])
|
n1: ndarray[float, 1] = np_zeros([1])
|
||||||
output_ndarray_float_1(n)
|
output_ndarray_float_1(n1)
|
||||||
|
|
||||||
|
k = 3 + int32(n1[0]) # to test variable shape inputs
|
||||||
|
n2: ndarray[float, 1] = np_zeros(k * k)
|
||||||
|
output_ndarray_float_1(n2)
|
||||||
|
|
||||||
|
n3: ndarray[float, 1] = np_zeros((k * 2,))
|
||||||
|
output_ndarray_float_1(n3)
|
||||||
|
|
||||||
|
dim4 = (3, 2 * k)
|
||||||
|
n4: ndarray[float, 2] = np_zeros(dim4)
|
||||||
|
output_ndarray_float_2(n4)
|
||||||
|
|
||||||
def test_ndarray_ones():
|
def test_ndarray_ones():
|
||||||
n: ndarray[float, 1] = np_ones([1])
|
n: ndarray[float, 1] = np_ones([1])
|
||||||
|
|
Loading…
Reference in New Issue