[artiq] codegen: Reimplement polymorphic_print for strided ndarray

Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
This commit is contained in:
David Mak 2024-11-29 16:54:31 +08:00
parent bbc68b8b1a
commit b40e9bca28
1 changed files with 39 additions and 53 deletions

View File

@ -16,14 +16,13 @@ use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::{
codegen::{
expr::{destructure_range, gen_call},
irrt::ndarray::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
type_aligned_alloca,
types::NDArrayType,
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
RangeValue, UntypedArrayLikeAccessor,
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue,
UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
},
@ -1339,63 +1338,50 @@ fn polymorphic_print<'ctx>(
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_pointer_value(
value.into_pointer_value(),
llvm_elem_ty,
None,
llvm_usize,
None,
);
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
.map_value(value.into_pointer_value(), None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
let num_0 = llvm_usize.const_zero();
polymorphic_print(
ctx,
generator,
&[(elem_ty, elem.into())],
"",
None,
true,
as_rtio,
)?;
// Print `ndarray` as a flat list delimited by interspersed with ", \0"
ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
let i = hdl.get_index(ctx);
let scalar = hdl.get_scalar(ctx);
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::ULT, i, last, "")
.unwrap())
},
|generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default());
// if (i != 0) puts(", ");
gen_if_callback(
generator,
ctx,
|_, ctx| {
let not_first = ctx
.builder
.build_int_compare(IntPredicate::NE, i, num_0, "")
.unwrap();
Ok(not_first)
},
|generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default());
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
// Print element
polymorphic_print(
ctx,
generator,
&[(dtype, scalar.into())],
"",
None,
true,
as_rtio,
)?;
Ok(())
})?;
fmt.push_str(")]");
flush(ctx, generator, &mut fmt, &mut args);