[artiq] codegen: Reimplement polymorphic_print for strided ndarray

Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
This commit is contained in:
David Mak 2024-11-29 16:54:31 +08:00
parent bbc68b8b1a
commit b40e9bca28
1 changed files with 39 additions and 53 deletions

View File

@ -16,14 +16,13 @@ use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::ndarray::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
type_aligned_alloca, type_aligned_alloca,
types::NDArrayType, types::NDArrayType,
values::{ values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue,
RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -1339,63 +1338,50 @@ fn polymorphic_print<'ctx>(
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
fmt.push_str("array(["); fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_pointer_value( let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
value.into_pointer_value(), let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
llvm_elem_ty, .map_value(value.into_pointer_value(), None);
None,
llvm_usize,
None,
);
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
gen_for_callback_incrementing( let num_0 = llvm_usize.const_zero();
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
polymorphic_print( // Print `ndarray` as a flat list delimited by interspersed with ", \0"
ctx, ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
generator, let i = hdl.get_index(ctx);
&[(elem_ty, elem.into())], let scalar = hdl.get_scalar(ctx);
"",
None,
true,
as_rtio,
)?;
gen_if_callback( // if (i != 0) puts(", ");
generator, gen_if_callback(
ctx, generator,
|_, ctx| { ctx,
Ok(ctx |_, ctx| {
.builder let not_first = ctx
.build_int_compare(IntPredicate::ULT, i, last, "") .builder
.unwrap()) .build_int_compare(IntPredicate::NE, i, num_0, "")
}, .unwrap();
|generator, ctx| { Ok(not_first)
printf(ctx, generator, ", \0".into(), Vec::default()); },
|generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default());
Ok(())
},
|_, _| Ok(()),
)?;
Ok(()) // Print element
}, polymorphic_print(
|_, _| Ok(()), ctx,
)?; generator,
&[(dtype, scalar.into())],
Ok(()) "",
}, None,
llvm_usize.const_int(1, false), true,
)?; as_rtio,
)?;
Ok(())
})?;
fmt.push_str(")]"); fmt.push_str(")]");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);