core/typecheck: First btis of NumPy-like array type inference
For readability of the codebase, I chose ndarray for the name of the type, while [numpy.]array() is the name of the most commonly used constructor.
This commit is contained in:
parent
8454741f9e
commit
72cb693e2e
@ -871,6 +871,58 @@ impl<'a> Inferencer<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// array() is a "magic" function call that determines the number of
|
||||
// dimensions in the result from the nesting of the array argument type,
|
||||
// to match the host Python NumPy API.
|
||||
if id == "array".into() {
|
||||
if args.is_empty() {
|
||||
return report_error(
|
||||
"`array()` expects at least one argument (contents in list form)",
|
||||
func_location,
|
||||
);
|
||||
}
|
||||
|
||||
if args.len() > 2 || !keywords.is_empty() {
|
||||
return report_error(
|
||||
"Additional `array()` arguments not yet implemented",
|
||||
func_location,
|
||||
);
|
||||
// TODO: Implement `dtype=` kwarg.
|
||||
}
|
||||
|
||||
let list_arg = self.fold_expr(args.remove(0))?;
|
||||
// TODO: Implement special case for emtpy arrays (e.g. `array([[]]))`)
|
||||
// to match NumPy.
|
||||
|
||||
let mut num_dims = 0;
|
||||
let mut elem_type = list_arg.custom.unwrap();
|
||||
while let TypeEnum::TList { ty } = &*self.unifier.get_ty(elem_type) {
|
||||
elem_type = *ty;
|
||||
num_dims += 1;
|
||||
}
|
||||
if num_dims == 0 {
|
||||
return report_error(
|
||||
"expected list argument to array(), not xxx",
|
||||
func_location,
|
||||
);
|
||||
}
|
||||
|
||||
let custom =
|
||||
Some(self.unifier.add_ty(TypeEnum::TNDArray { ty: elem_type, num_dims }));
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: None,
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id, ctx },
|
||||
}),
|
||||
args: vec![list_arg],
|
||||
keywords: vec![],
|
||||
},
|
||||
});
|
||||
}
|
||||
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
||||
} else {
|
||||
func
|
||||
|
@ -513,6 +513,14 @@ impl TestEnvironment {
|
||||
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "listcomp test")]
|
||||
#[test_case(
|
||||
indoc! {"
|
||||
a = array([1, 2])
|
||||
b = array([[1, 2], [3, 4]])
|
||||
"},
|
||||
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "array test")]
|
||||
#[test_case(indoc! {"
|
||||
a = virtual(Bar(), Bar)
|
||||
b = a.b()
|
||||
@ -533,6 +541,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
let mut env = TestEnvironment::new();
|
||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
|
||||
defined_identifiers.insert("array".into());
|
||||
defined_identifiers.insert("virtual".into());
|
||||
let mut inferencer = env.get_inferencer();
|
||||
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||
|
@ -137,6 +137,10 @@ pub enum TypeEnum {
|
||||
TList {
|
||||
ty: Type,
|
||||
},
|
||||
TNDArray {
|
||||
ty: Type,
|
||||
num_dims: u8,
|
||||
},
|
||||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: Mapping<StrRef, (Type, bool)>,
|
||||
@ -156,6 +160,7 @@ impl TypeEnum {
|
||||
TypeEnum::TVar { .. } => "TVar",
|
||||
TypeEnum::TTuple { .. } => "TTuple",
|
||||
TypeEnum::TList { .. } => "TList",
|
||||
TypeEnum::TNDArray { .. } => "TNDArray",
|
||||
TypeEnum::TObj { .. } => "TObj",
|
||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||
TypeEnum::TCall { .. } => "TCall",
|
||||
@ -387,6 +392,7 @@ impl Unifier {
|
||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||
TCall { .. } => false,
|
||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||
TObj { params: vars, .. } => {
|
||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||
@ -904,6 +910,13 @@ impl Unifier {
|
||||
TypeEnum::TList { ty } => {
|
||||
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
||||
}
|
||||
TypeEnum::TNDArray { ty, num_dims } => {
|
||||
format!(
|
||||
"ndarray[{}, {}]",
|
||||
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
|
||||
num_dims
|
||||
)
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
format!(
|
||||
"virtual[{}]",
|
||||
|
Loading…
Reference in New Issue
Block a user