forked from M-Labs/nac3
core/type_inferencer: Add special rule for np_array
This commit is contained in:
parent
82bf14785b
commit
c2fdb12397
|
@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
TopLevelContext,
|
TopLevelContext,
|
||||||
},
|
},
|
||||||
|
@ -1254,6 +1254,77 @@ impl<'a> Inferencer<'a> {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 1-argument ndarray n-dimensional creation functions
|
||||||
|
if id == &"np_array".into() && args.len() == 1 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let keywords = keywords.iter()
|
||||||
|
.map(|v| fold::fold_keyword(self, v.clone()))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let ndmin_kw = keywords.iter()
|
||||||
|
.find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
|
||||||
|
|
||||||
|
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
|
||||||
|
let ndims = if let Some(ndmin_kw) = ndmin_kw {
|
||||||
|
match &ndmin_kw.node.value.node {
|
||||||
|
ExprKind::Constant { value, .. } => match value {
|
||||||
|
ast::Constant::Int(value) => *value as u64,
|
||||||
|
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
||||||
|
};
|
||||||
|
let ndims = self.unifier.get_fresh_literal(
|
||||||
|
vec![SymbolValue::U64(ndims)],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let ret = make_ndarray_ty(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "object".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None
|
||||||
|
},
|
||||||
|
FuncArg {
|
||||||
|
name: "copy".into(),
|
||||||
|
ty: self.primitives.bool,
|
||||||
|
default_value: Some(SymbolValue::Bool(true)),
|
||||||
|
},
|
||||||
|
FuncArg {
|
||||||
|
name: "ndmin".into(),
|
||||||
|
ty: self.primitives.int32,
|
||||||
|
default_value: Some(SymbolValue::U32(0)),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1264,11 +1335,10 @@ impl<'a> Inferencer<'a> {
|
||||||
mut args: Vec<ast::Expr<()>>,
|
mut args: Vec<ast::Expr<()>>,
|
||||||
keywords: Vec<Located<ast::KeywordData>>,
|
keywords: Vec<Located<ast::KeywordData>>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
||||||
let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
|
if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
|
||||||
return Ok(spec_call_func)
|
return Ok(spec_call_func)
|
||||||
} else {
|
}
|
||||||
func
|
|
||||||
};
|
|
||||||
let func = Box::new(self.fold_expr(func)?);
|
let func = Box::new(self.fold_expr(func)?);
|
||||||
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||||
let keywords = keywords
|
let keywords = keywords
|
||||||
|
|
Loading…
Reference in New Issue