core: Add multidimensional array helpers

This commit is contained in:
David Mak 2024-06-04 17:50:09 +08:00
parent fa8af37e84
commit dced3c2407
2 changed files with 89 additions and 1 deletions

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
@ -691,3 +692,35 @@ pub fn parse_parameter_default_value(
]))
}
}
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray =>
unpack_ndarray_var_tys(unifier, ty).0,
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
_ => ty
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
};
if values.len() > 1 {
todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented")
}
u64::try_from(values[0].clone()).unwrap()
}
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
_ => 0
}
}

View File

@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
@ -1254,6 +1254,61 @@ impl<'a> Inferencer<'a> {
}))
}
// 3-argument ndarray n-dimensional creation functions
if id == &"np_array".into() && args.len() == 3 {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
let arg2 = self.fold_expr(args.remove(0))?;
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
let ndims = arraylike_get_ndims(self.unifier, arg0.custom.unwrap());
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
);
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "object".into(),
ty: arg0.custom.unwrap(),
default_value: None
},
FuncArg {
name: "copy".into(),
ty: arg1.custom.unwrap(),
default_value: Some(SymbolValue::Bool(true)),
},
FuncArg {
name: "ndmin".into(),
ty: arg2.custom.unwrap(),
default_value: Some(SymbolValue::U32(0)),
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0, arg1, arg2],
keywords: vec![],
},
}))
}
Ok(None)
}