core/typecheck: First bits of NumPy-like array type inference

For readability of the codebase, I chose ndarray for the name of the
type, while [numpy.]array() is the name of the most commonly used
constructor.
This commit is contained in:
David Nadlinger 2022-04-22 20:49:18 +01:00
parent 8454741f9e
commit 72cb693e2e
3 changed files with 74 additions and 0 deletions

View File

@ -871,6 +871,58 @@ impl<'a> Inferencer<'a> {
} }
} }
} }
// array() is a "magic" function call that determines the number of
// dimensions in the result from the nesting of the array argument type,
// to match the host Python NumPy API.
if id == "array".into() {
    if args.is_empty() {
        return report_error(
            "`array()` expects at least one argument (contents in list form)",
            func_location,
        );
    }
    // Only the single positional list argument is consumed below, and the folded
    // call is rebuilt with exactly that one argument. Reject any further
    // positional or keyword arguments instead of silently dropping them.
    // TODO: Implement `dtype=` kwarg.
    if args.len() > 1 || !keywords.is_empty() {
        return report_error(
            "Additional `array()` arguments not yet implemented",
            func_location,
        );
    }
    let list_arg = self.fold_expr(args.remove(0))?;
    // TODO: Implement special case for empty arrays (e.g. `array([[]])`)
    // to match NumPy.
    // Peel off one `TList` layer per iteration: the nesting depth of the
    // argument's list type becomes the number of dimensions, and the innermost
    // non-list type becomes the element type.
    let mut num_dims = 0;
    // NOTE(review): assumes `fold_expr` always attaches a type to the folded
    // expression; otherwise this `unwrap()` panics — confirm.
    let mut elem_type = list_arg.custom.unwrap();
    while let TypeEnum::TList { ty } = &*self.unifier.get_ty(elem_type) {
        elem_type = *ty;
        num_dims += 1;
    }
    if num_dims == 0 {
        // The argument's type is not a list at all.
        return report_error(
            "`array()` expects a (possibly nested) list as its argument",
            func_location,
        );
    }
    let custom =
        Some(self.unifier.add_ty(TypeEnum::TNDArray { ty: elem_type, num_dims }));
    // Rebuild the call node around the already-folded argument. The callee name
    // itself carries no type (`custom: None`); the call expression as a whole is
    // typed as the freshly created ndarray type.
    return Ok(Located {
        location,
        custom,
        node: ExprKind::Call {
            func: Box::new(Located {
                custom: None,
                location: func.location,
                node: ExprKind::Name { id, ctx },
            }),
            args: vec![list_arg],
            keywords: vec![],
        },
    });
}
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
} else { } else {
func func

View File

@ -513,6 +513,14 @@ impl TestEnvironment {
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(), [("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
&[] &[]
; "listcomp test")] ; "listcomp test")]
#[test_case(
indoc! {"
a = array([1, 2])
b = array([[1, 2], [3, 4]])
"},
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(),
&[]
; "array test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = virtual(Bar(), Bar) a = virtual(Bar(), Bar)
b = a.b() b = a.b()
@ -533,6 +541,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("array".into());
defined_identifiers.insert("virtual".into()); defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers = defined_identifiers.clone();

View File

@ -137,6 +137,10 @@ pub enum TypeEnum {
TList { TList {
ty: Type, ty: Type,
}, },
TNDArray {
ty: Type,
num_dims: u8,
},
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: Mapping<StrRef, (Type, bool)>, fields: Mapping<StrRef, (Type, bool)>,
@ -156,6 +160,7 @@ impl TypeEnum {
TypeEnum::TVar { .. } => "TVar", TypeEnum::TVar { .. } => "TVar",
TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList", TypeEnum::TList { .. } => "TList",
TypeEnum::TNDArray { .. } => "TNDArray",
TypeEnum::TObj { .. } => "TObj", TypeEnum::TObj { .. } => "TObj",
TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall", TypeEnum::TCall { .. } => "TCall",
@ -387,6 +392,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } => self.is_concrete(*ty, allowed_typevars),
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
@ -904,6 +910,13 @@ impl Unifier {
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
} }
TypeEnum::TNDArray { ty, num_dims } => {
format!(
"ndarray[{}, {}]",
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
num_dims
)
}
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
format!( format!(
"virtual[{}]", "virtual[{}]",