forked from M-Labs/nac3
[core] codegen/ndarray: Implement np_{shape,strides}
Based on 40c24486: core/ndstrides: implement np_shape() and np_strides() These functions are not important, but they are handy for debugging. `np.strides()` is not an actual NumPy function, but `ndarray.strides` is used.
This commit is contained in:
parent
9ffa2d6552
commit
12358c57b1
@ -1,19 +1,23 @@
|
|||||||
|
use std::iter::repeat_n;
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor,
|
||||||
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||||
|
UntypedArrayLikeMutator,
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
irrt,
|
irrt,
|
||||||
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
type_aligned_alloca,
|
type_aligned_alloca,
|
||||||
types::{ndarray::NDArrayType, structure::StructField},
|
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
pub use contiguous::*;
|
pub use contiguous::*;
|
||||||
@ -417,13 +421,85 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create the shape tuple of this ndarray like
|
||||||
|
/// [`np.shape(<ndarray>)`](https://numpy.org/doc/stable/reference/generated/numpy.shape.html).
|
||||||
|
///
|
||||||
|
/// All elements in the tuple are `i32`.
|
||||||
|
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> TupleValue<'ctx> {
|
||||||
|
assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let objects = (0..self.ndims.unwrap())
|
||||||
|
.map(|i| {
|
||||||
|
let dim = unsafe {
|
||||||
|
self.shape().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&self.llvm_usize.const_int(i, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap()
|
||||||
|
})
|
||||||
|
.map(|obj| obj.as_basic_value_enum())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
TupleType::new(
|
||||||
|
generator,
|
||||||
|
ctx.ctx,
|
||||||
|
&repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(),
|
||||||
|
)
|
||||||
|
.construct_from_objects(ctx, objects, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create the strides tuple of this ndarray like
|
||||||
|
/// [`<ndarray>.strides`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html).
|
||||||
|
///
|
||||||
|
/// All elements in the tuple are `i32`.
|
||||||
|
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> TupleValue<'ctx> {
|
||||||
|
assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let objects = (0..self.ndims.unwrap())
|
||||||
|
.map(|i| {
|
||||||
|
let dim = unsafe {
|
||||||
|
self.strides().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&self.llvm_usize.const_int(i, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap()
|
||||||
|
})
|
||||||
|
.map(|obj| obj.as_basic_value_enum())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
TupleType::new(
|
||||||
|
generator,
|
||||||
|
ctx.ctx,
|
||||||
|
&repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(),
|
||||||
|
)
|
||||||
|
.construct_from_objects(ctx, objects, None)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn is_unsized(&self) -> Option<bool> {
|
pub fn is_unsized(&self) -> Option<bool> {
|
||||||
self.ndims.map(|ndims| ndims == 0)
|
self.ndims.map(|ndims| ndims == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
|
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
|
||||||
/// Otherwise, do nothing and return the ndarray itself.
|
/// Otherwise, do nothing and return the ndarray itself.
|
||||||
// TODO: Rename to get_unsized_element
|
// TODO: Rename to get_unsized_element
|
||||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -14,6 +14,7 @@ use crate::{
|
|||||||
builtin_fns,
|
builtin_fns,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
|
types::ndarray::NDArrayType,
|
||||||
values::{ProxyValue, RangeValue},
|
values::{ProxyValue, RangeValue},
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
@ -368,6 +369,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpEye
|
| PrimDef::FunNpEye
|
||||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||||
|
self.build_ndarray_property_getter_function(prim)
|
||||||
|
}
|
||||||
|
|
||||||
PrimDef::FunStr => self.build_str_function(),
|
PrimDef::FunStr => self.build_str_function(),
|
||||||
|
|
||||||
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
||||||
@ -1242,6 +1247,54 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]);
|
||||||
|
|
||||||
|
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.primitives.ndarray],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||||
|
// The function signatures of `np_shape` an `np_size` are the same.
|
||||||
|
// Mixed together for convenience.
|
||||||
|
|
||||||
|
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
|
||||||
|
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&into_var_map([in_ndarray_ty]),
|
||||||
|
prim.name(),
|
||||||
|
ret_ty,
|
||||||
|
&[(in_ndarray_ty.ty, "a")],
|
||||||
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty)
|
||||||
|
.map_value(ndarray.into_pointer_value(), None);
|
||||||
|
|
||||||
|
let result_tuple = match prim {
|
||||||
|
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
|
||||||
|
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(result_tuple.as_base_value().into()))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build the `str()` function.
|
/// Build the `str()` function.
|
||||||
fn build_str_function(&mut self) -> TopLevelDef {
|
fn build_str_function(&mut self) -> TopLevelDef {
|
||||||
let prim = PrimDef::FunStr;
|
let prim = PrimDef::FunStr;
|
||||||
|
@ -54,6 +54,10 @@ pub enum PrimDef {
|
|||||||
FunNpEye,
|
FunNpEye,
|
||||||
FunNpIdentity,
|
FunNpIdentity,
|
||||||
|
|
||||||
|
// NumPy ndarray property getters
|
||||||
|
FunNpShape,
|
||||||
|
FunNpStrides,
|
||||||
|
|
||||||
// Miscellaneous NumPy & SciPy functions
|
// Miscellaneous NumPy & SciPy functions
|
||||||
FunNpRound,
|
FunNpRound,
|
||||||
FunNpFloor,
|
FunNpFloor,
|
||||||
@ -240,6 +244,10 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpEye => fun("np_eye", None),
|
PrimDef::FunNpEye => fun("np_eye", None),
|
||||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||||
|
|
||||||
|
// NumPy NDArray property getters,
|
||||||
|
PrimDef::FunNpShape => fun("np_shape", None),
|
||||||
|
PrimDef::FunNpStrides => fun("np_strides", None),
|
||||||
|
|
||||||
// Miscellaneous NumPy & SciPy functions
|
// Miscellaneous NumPy & SciPy functions
|
||||||
PrimDef::FunNpRound => fun("np_round", None),
|
PrimDef::FunNpRound => fun("np_round", None),
|
||||||
PrimDef::FunNpFloor => fun("np_floor", None),
|
PrimDef::FunNpFloor => fun("np_floor", None),
|
||||||
|
@ -8,5 +8,5 @@ expression: res_vec
|
|||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -3,7 +3,7 @@ use std::{
|
|||||||
cmp::max,
|
cmp::max,
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
convert::{From, TryInto},
|
convert::{From, TryInto},
|
||||||
iter::once,
|
iter::{once, repeat_n},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
|
||||||
|
let ndarray = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
|
||||||
|
|
||||||
|
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
|
||||||
|
let ret_ty = TypeEnum::TTuple {
|
||||||
|
ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(),
|
||||||
|
is_vararg_ctx: false,
|
||||||
|
};
|
||||||
|
let ret_ty = self.unifier.add_ty(ret_ty);
|
||||||
|
|
||||||
|
let func_ty = TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![FuncArg {
|
||||||
|
name: "a".into(),
|
||||||
|
default_value: None,
|
||||||
|
ty: ndarray.custom.unwrap(),
|
||||||
|
is_vararg: false,
|
||||||
|
}],
|
||||||
|
ret: ret_ty,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
});
|
||||||
|
let func_ty = self.unifier.add_ty(func_ty);
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret_ty),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(func_ty),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![ndarray],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
if id == &"np_dot".into() {
|
if id == &"np_dot".into() {
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let arg1 = self.fold_expr(args.remove(0))?;
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
|
@ -179,6 +179,10 @@ def patch(module):
|
|||||||
module.np_identity = np.identity
|
module.np_identity = np.identity
|
||||||
module.np_array = np.array
|
module.np_array = np.array
|
||||||
|
|
||||||
|
# NumPy NDArray property getters
|
||||||
|
module.np_shape = np.shape
|
||||||
|
module.np_strides = lambda ndarray: ndarray.strides
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
|
Loading…
Reference in New Issue
Block a user