core/typecheck: First btis of NumPy-like array type inference
For readability of the codebase, I chose ndarray for the name of the type, while [numpy.]array() is the name of the most commonly used constructor.
This commit is contained in:
parent
8454741f9e
commit
72cb693e2e
@ -871,6 +871,58 @@ impl<'a> Inferencer<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// array() is a "magic" function call that determines the number of
|
||||||
|
// dimensions in the result from the nesting of the array argument type,
|
||||||
|
// to match the host Python NumPy API.
|
||||||
|
if id == "array".into() {
|
||||||
|
if args.is_empty() {
|
||||||
|
return report_error(
|
||||||
|
"`array()` expects at least one argument (contents in list form)",
|
||||||
|
func_location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.len() > 2 || !keywords.is_empty() {
|
||||||
|
return report_error(
|
||||||
|
"Additional `array()` arguments not yet implemented",
|
||||||
|
func_location,
|
||||||
|
);
|
||||||
|
// TODO: Implement `dtype=` kwarg.
|
||||||
|
}
|
||||||
|
|
||||||
|
let list_arg = self.fold_expr(args.remove(0))?;
|
||||||
|
// TODO: Implement special case for emtpy arrays (e.g. `array([[]]))`)
|
||||||
|
// to match NumPy.
|
||||||
|
|
||||||
|
let mut num_dims = 0;
|
||||||
|
let mut elem_type = list_arg.custom.unwrap();
|
||||||
|
while let TypeEnum::TList { ty } = &*self.unifier.get_ty(elem_type) {
|
||||||
|
elem_type = *ty;
|
||||||
|
num_dims += 1;
|
||||||
|
}
|
||||||
|
if num_dims == 0 {
|
||||||
|
return report_error(
|
||||||
|
"expected list argument to array(), not xxx",
|
||||||
|
func_location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let custom =
|
||||||
|
Some(self.unifier.add_ty(TypeEnum::TNDArray { ty: elem_type, num_dims }));
|
||||||
|
return Ok(Located {
|
||||||
|
location,
|
||||||
|
custom,
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: None,
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id, ctx },
|
||||||
|
}),
|
||||||
|
args: vec![list_arg],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
||||||
} else {
|
} else {
|
||||||
func
|
func
|
||||||
|
@ -513,6 +513,14 @@ impl TestEnvironment {
|
|||||||
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
|
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
|
||||||
&[]
|
&[]
|
||||||
; "listcomp test")]
|
; "listcomp test")]
|
||||||
|
#[test_case(
|
||||||
|
indoc! {"
|
||||||
|
a = array([1, 2])
|
||||||
|
b = array([[1, 2], [3, 4]])
|
||||||
|
"},
|
||||||
|
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(),
|
||||||
|
&[]
|
||||||
|
; "array test")]
|
||||||
#[test_case(indoc! {"
|
#[test_case(indoc! {"
|
||||||
a = virtual(Bar(), Bar)
|
a = virtual(Bar(), Bar)
|
||||||
b = a.b()
|
b = a.b()
|
||||||
@ -533,6 +541,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
|||||||
let mut env = TestEnvironment::new();
|
let mut env = TestEnvironment::new();
|
||||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||||
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
|
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
|
||||||
|
defined_identifiers.insert("array".into());
|
||||||
defined_identifiers.insert("virtual".into());
|
defined_identifiers.insert("virtual".into());
|
||||||
let mut inferencer = env.get_inferencer();
|
let mut inferencer = env.get_inferencer();
|
||||||
inferencer.defined_identifiers = defined_identifiers.clone();
|
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||||
|
@ -137,6 +137,10 @@ pub enum TypeEnum {
|
|||||||
TList {
|
TList {
|
||||||
ty: Type,
|
ty: Type,
|
||||||
},
|
},
|
||||||
|
TNDArray {
|
||||||
|
ty: Type,
|
||||||
|
num_dims: u8,
|
||||||
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: Mapping<StrRef, (Type, bool)>,
|
fields: Mapping<StrRef, (Type, bool)>,
|
||||||
@ -156,6 +160,7 @@ impl TypeEnum {
|
|||||||
TypeEnum::TVar { .. } => "TVar",
|
TypeEnum::TVar { .. } => "TVar",
|
||||||
TypeEnum::TTuple { .. } => "TTuple",
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
TypeEnum::TList { .. } => "TList",
|
TypeEnum::TList { .. } => "TList",
|
||||||
|
TypeEnum::TNDArray { .. } => "TNDArray",
|
||||||
TypeEnum::TObj { .. } => "TObj",
|
TypeEnum::TObj { .. } => "TObj",
|
||||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||||
TypeEnum::TCall { .. } => "TCall",
|
TypeEnum::TCall { .. } => "TCall",
|
||||||
@ -387,6 +392,7 @@ impl Unifier {
|
|||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
||||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
TObj { params: vars, .. } => {
|
TObj { params: vars, .. } => {
|
||||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||||
@ -904,6 +910,13 @@ impl Unifier {
|
|||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
||||||
}
|
}
|
||||||
|
TypeEnum::TNDArray { ty, num_dims } => {
|
||||||
|
format!(
|
||||||
|
"ndarray[{}, {}]",
|
||||||
|
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
|
||||||
|
num_dims
|
||||||
|
)
|
||||||
|
}
|
||||||
TypeEnum::TVirtual { ty } => {
|
TypeEnum::TVirtual { ty } => {
|
||||||
format!(
|
format!(
|
||||||
"virtual[{}]",
|
"virtual[{}]",
|
||||||
|
Loading…
Reference in New Issue
Block a user