diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs
index aa7f2a4d..e06ed4d2 100644
--- a/nac3artiq/src/symbol_resolver.rs
+++ b/nac3artiq/src/symbol_resolver.rs
@@ -1,6 +1,14 @@
-use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
+use inkwell::{
+ types::{BasicType, BasicTypeEnum},
+ values::BasicValueEnum,
+ AddressSpace,
+};
+use itertools::Itertools;
use nac3core::{
- codegen::{CodeGenContext, CodeGenerator},
+ codegen::{
+ classes::{NDArrayType, ProxyType},
+ CodeGenContext, CodeGenerator,
+ },
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
@@ -670,7 +678,7 @@ impl InnerResolver {
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
- let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
+ let len: usize = obj.getattr("ndim")?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(ty),
@@ -679,10 +687,10 @@ impl InnerResolver {
));
Ok(Ok(extracted_ty))
} else {
- let actual_ty =
- self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
- match actual_ty {
- Ok(t) => match unifier.unify(ty, t) {
+ let dtype = obj.getattr("dtype")?.getattr("type")?;
+ let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?;
+ match dtype_ty {
+ Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => {
let ndarray_ty =
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
@@ -966,7 +974,143 @@ impl InnerResolver {
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.ndarray {
- todo!()
+ let id_str = id.to_string();
+
+ if let Some(global) = ctx.module.get_global(&id_str) {
+ return Ok(Some(global.as_pointer_value().into()));
+ }
+
+ let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
+ {
+ expected_ty
+ } else {
+ unreachable!("must be ndarray")
+ };
+ let (ndarray_dtype, ndarray_ndims) =
+ unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
+
+ let llvm_usize = generator.get_size_type(ctx.ctx);
+ let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
+ let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
+
+ {
+ if self.global_value_ids.read().contains_key(&id) {
+ let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
+ ctx.module.add_global(
+ ndarray_llvm_ty.as_underlying_type(),
+ Some(AddressSpace::default()),
+ &id_str,
+ )
+ });
+ return Ok(Some(global.as_pointer_value().into()));
+ }
+ self.global_value_ids.write().insert(id, obj.into());
+ }
+
+ let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
+ else {
+ unreachable!("Expected Literal for ndarray_ndims")
+ };
+
+ let ndarray_ndims = if values.len() == 1 {
+ values[0].clone()
+ } else {
+ todo!("Unpacking literal of more than one element unimplemented")
+ };
+ let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
+ unreachable!("Expected u64 value for ndarray_ndims")
+ };
+
+ // Obtain the shape of the ndarray
+ let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
+ assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
+ let shape_values: Result