core/ndstrides: change get_llvm_type to new ndarray

This commit is contained in:
lyken 2024-08-05 10:35:51 +08:00
parent 876d1bbfe3
commit 1873b35654
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8

View File

@ -33,7 +33,7 @@ use std::sync::{
Arc, Arc,
}; };
use std::thread; use std::thread;
use structure::{cslice::CSlice, exception::Exception}; use structure::{cslice::CSlice, exception::Exception, ndarray::NpArray};
pub mod builtin_fns; pub mod builtin_fns;
pub mod classes; pub mod classes;
@ -45,6 +45,7 @@ pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod model; pub mod model;
pub mod numpy; pub mod numpy;
pub mod numpy_new;
pub mod stmt; pub mod stmt;
pub mod structure; pub mod structure;
@ -493,12 +494,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let tyctx = generator.type_context(ctx);
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into() let pndarray_model = PtrModel(StructModel(NpArray));
pndarray_model.get_type(tyctx, ctx).as_basic_type_enum()
} }
_ => unreachable!( _ => unreachable!(