artiq/symbol_resolver: Determine global array type by init-val type

This commit is contained in:
David Mak 2024-07-12 14:52:33 +08:00
parent 9b5fb69875
commit 8f95c707d7
1 changed files with 34 additions and 19 deletions

View File

@ -1073,6 +1073,8 @@ impl InnerResolver {
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
// TODO: Special handling required for strings, since there are two representations:
// struct %str and [n x i8].
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
@ -1140,31 +1142,44 @@ impl InnerResolver {
}) })
}) })
.collect(); .collect();
let data = data?.unwrap().into_iter(); let data = data?.unwrap();
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => { let make_llvm_array =
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) |llvm_ty: BasicTypeEnum<'ctx>, elems: Vec<BasicValueEnum<'ctx>>| {
} debug_assert!(elems.iter().all(|elem| elem.get_type() == llvm_ty));
BasicTypeEnum::IntType(ty) => { match llvm_ty {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) BasicTypeEnum::ArrayType(ty) => ty.const_array(
} &elems.into_iter().map(BasicValueEnum::into_array_value).collect_vec(),
),
BasicTypeEnum::PointerType(ty) => { BasicTypeEnum::FloatType(ty) => ty.const_array(
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) &elems.into_iter().map(BasicValueEnum::into_float_value).collect_vec(),
} ),
BasicTypeEnum::StructType(ty) => { BasicTypeEnum::IntType(ty) => ty.const_array(
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) &elems.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
} ),
BasicTypeEnum::PointerType(ty) => ty.const_array(
&elems
.into_iter()
.map(BasicValueEnum::into_pointer_value)
.collect_vec(),
),
BasicTypeEnum::StructType(ty) => ty.const_array(
&elems.into_iter().map(BasicValueEnum::into_struct_value).collect_vec(),
),
BasicTypeEnum::VectorType(_) => unreachable!(), BasicTypeEnum::VectorType(_) => unreachable!(),
}
}; };
let ndarray_dtype_llvm_ty =
if data.is_empty() { ndarray_dtype_llvm_ty } else { data[0].get_type() };
let data = make_llvm_array(ndarray_dtype_llvm_ty, data);
// create a global for ndarray.data and initialize it using the elements // create a global for ndarray.data and initialize it using the elements
let data_global = ctx.module.add_global( let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32), ndarray_dtype_llvm_ty.array_type(sz as u32),