ndstrides: [5] Implement np_{identity,eye,shape,strides,size} and ndarray.{fill.copy} #515

Open
lyken wants to merge 4 commits from ndstrides-5-miscfuncs into ndstrides-4-nparray
12 changed files with 283 additions and 79 deletions

View File

@ -1905,15 +1905,23 @@ pub fn gen_ndarray_eye<'ctx>(
)) ))
}?; }?;
call_ndarray_eye_impl( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
generator,
context, let nrows = Int(Int32)
context.primitives.float, .check_value(generator, context.ctx, nrows_arg)
nrows_arg.into_int_value(), .unwrap()
ncols_arg.into_int_value(), .s_extend_or_bit_cast(generator, context, SizeT);
offset_arg.into_int_value(), let ncols = Int(Int32)
) .check_value(generator, context.ctx, ncols_arg)
.map(NDArrayValue::into) .unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let offset = Int(Int32)
.check_value(generator, context.ctx, offset_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset);
Ok(ndarray.instance.value)
} }
/// Generates LLVM IR for `ndarray.identity`. /// Generates LLVM IR for `ndarray.identity`.
@ -1927,20 +1935,15 @@ pub fn gen_ndarray_identity<'ctx>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx); let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let n_ty = fun.0.args[0].ty; let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl( let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap();
generator, let n = n.s_extend_or_bit_cast(generator, context, SizeT);
context, let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n);
context.primitives.float, Ok(ndarray.instance.value)
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
)
.map(NDArrayValue::into)
} }
/// Generates LLVM IR for `ndarray.copy`. /// Generates LLVM IR for `ndarray.copy`.
@ -1954,20 +1957,14 @@ pub fn gen_ndarray_copy<'ctx>(
assert!(obj.is_some()); assert!(obj.is_some());
assert!(args.is_empty()); assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl( let this = AnyObject { value: this_arg, ty: this_ty };
generator, let this = NDArrayObject::from_object(generator, context, this);
context, let ndarray = this.make_copy(generator, context);
this_elem_ty, Ok(ndarray.instance.value)
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
} }
/// Generates LLVM IR for `ndarray.fill`. /// Generates LLVM IR for `ndarray.fill`.
@ -1981,48 +1978,15 @@ pub fn gen_ndarray_fill<'ctx>(
assert!(obj.is_some()); assert!(obj.is_some());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj let this_arg =
.as_ref() obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty; let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened( let this = AnyObject { value: this_arg, ty: this_ty };
generator, let this = NDArrayObject::from_object(generator, context, this);
context, this.fill(generator, context, value_arg);
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
codegen_unreachable!(ctx)
};
Ok(value)
},
)?;
Ok(()) Ok(())
} }

View File

@ -1,4 +1,4 @@
use inkwell::values::BasicValueEnum; use inkwell::{values::BasicValueEnum, IntPredicate};
use super::NDArrayObject; use super::NDArrayObject;
use crate::{ use crate::{
@ -122,4 +122,54 @@ impl<'ctx> NDArrayObject<'ctx> {
let fill_value = ndarray_one_value(generator, ctx, dtype); let fill_value = ndarray_one_value(generator, ctx, dtype);
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
} }
/// Create an ndarray like `np.eye`.
pub fn make_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: Instance<'ctx, Int<SizeT>>,
ncols: Instance<'ctx, Int<SizeT>>,
offset: Instance<'ctx, Int<SizeT>>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
// Load up `row_i` and `col_i` from indices.
let row_i = nditer.get_indices().get_index_const(generator, ctx, 0);
let col_i = nditer.get_indices().get_index_const(generator, ctx, 1);
let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i);
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.identity`.
pub fn make_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Instance<'ctx, Int<SizeT>>,
) -> Self {
// Convenient implementation
let offset = Int(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset)
}
} }

View File

@ -5,7 +5,7 @@ use inkwell::{
AddressSpace, AddressSpace,
}; };
use super::any::AnyObject; use super::{any::AnyObject, tuple::TupleObject};
use crate::{ use crate::{
codegen::{ codegen::{
irrt::{ irrt::{
@ -417,6 +417,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>, value: BasicValueEnum<'ctx>,
) { ) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx); let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap(); ctx.builder.build_store(p, value).unwrap();
@ -424,6 +426,62 @@ impl<'ctx> NDArrayObject<'ctx> {
}) })
.unwrap(); .unwrap();
} }
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
/// Create the strides tuple of this ndarray like `<ndarray>.strides`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
} }
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both. /// A convenience enum for implementing functions that acts on scalars or ndarrays or both.

View File

@ -19,7 +19,9 @@ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ProxyValue, RangeValue}, classes::{ProxyValue, RangeValue},
model::*,
numpy::*, numpy::*,
object::{any::AnyObject, ndarray::NDArrayObject},
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
@ -512,6 +514,10 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye | PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
self.build_ndarray_property_getter_function(prim)
}
PrimDef::FunStr => self.build_str_function(), PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -1386,6 +1392,78 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpSize => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
self.primitives.int32,
&[(in_ndarray_ty.ty, "a")],
Box::new(|ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size =
ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32);
Ok(Some(size.value.as_basic_value_enum()))
}),
),
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
// The function signatures of `np_shape` an `np_size` are the same.
// Mixed together for convenience.
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
ret_ty,
&[(in_ndarray_ty.ty, "a")],
Box::new(move |ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let result_tuple = match prim {
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
_ => unreachable!(),
};
Ok(Some(result_tuple.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build the `str()` function. /// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef { fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr; let prim = PrimDef::FunStr;
@ -1888,8 +1966,8 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&into_var_map([ndarray_ty]), &into_var_map([ndarray_ty]),
prim.name(), prim.name(),
ndarray_ty.ty, self.ndarray_num_ty,
&[(ndarray_ty.ty, "x")], &[(self.ndarray_num_ty, "x")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg_val = let arg_val =

View File

@ -54,6 +54,11 @@ pub enum PrimDef {
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
// NumPy ndarray property getters
FunNpSize,
FunNpShape,
FunNpStrides,
// Miscellaneous NumPy & SciPy functions // Miscellaneous NumPy & SciPy functions
FunNpRound, FunNpRound,
FunNpFloor, FunNpFloor,
@ -240,6 +245,11 @@ impl PrimDef {
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
// NumPy NDArray property getters,
PrimDef::FunNpSize => fun("np_size", None),
PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None),
// Miscellaneous NumPy & SciPy functions // Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
] ]

View File

@ -3,7 +3,7 @@ use std::{
cmp::max, cmp::max,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::{From, TryInto}, convert::{From, TryInto},
iter::once, iter::{self, once},
sync::Arc, sync::Arc,
}; };
@ -1235,6 +1235,45 @@ impl<'a> Inferencer<'a> {
})); }));
} }
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() { if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?;

View File

@ -179,6 +179,11 @@ def patch(module):
module.np_identity = np.identity module.np_identity = np.identity
module.np_array = np.array module.np_array = np.array
# NumPy NDArray property getters
module.np_size = np.size
module.np_shape = np.shape
module.np_strides = lambda ndarray: ndarray.strides
# NumPy Math functions # NumPy Math functions
module.np_isnan = np.isnan module.np_isnan = np.isnan
module.np_isinf = np.isinf module.np_isinf = np.isinf