1
0
forked from M-Labs/nac3

core: Implement codegen for indexing into ndarray

This commit is contained in:
David Mak 2024-02-19 17:10:18 +08:00
parent 0d5c53e60c
commit cc538d221a

View File

@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
classes::{ListValue, RangeValue},
classes::{ListValue, NDArrayValue, RangeValue},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check,
get_llvm_type,
@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
}
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
ndims: Type,
v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_void = ctx.ctx.void_type();
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
unreachable!()
};
let ndims = values.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::I32(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
SymbolValue::I64(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
assert!(!ndims.is_empty());
let ndarray_ty_enum = TypeEnum::TNDArray {
ty,
ndims: ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
None,
),
};
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
Ok(Some(v.get_data()
.get_const(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
None,
)
.into()))
} else {
// Accessing an element from a multi-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None
)?;
let ndarray = NDArrayValue::from_ptr_val(
subscripted_ndarray,
llvm_usize,
None
);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
let memcpy_fn_name = format!(
"llvm.memcpy.p0i8.p0i8.i{}",
generator.get_size_type(ctx.ctx).get_bit_width(),
);
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pi8.into(),
llvm_pi8.into(),
llvm_usize.into(),
llvm_i1.into(),
],
false,
);
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.get_dims().ptr_offset(
ctx,
generator,
llvm_usize.const_int(1, false),
None,
);
ctx.builder.build_call(
memcpy_fn,
&[
ctx.builder.build_bitcast(
ndarray.get_dims().get_ptr(ctx),
llvm_pi8,
"",
).into(),
ctx.builder.build_bitcast(
v_dims_src_ptr,
llvm_pi8,
"",
).into(),
ctx.builder.build_int_mul(
ndarray_num_dims.into(),
llvm_usize.size_of(),
"",
).into(),
llvm_i1.const_zero().into(),
],
"",
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.get_dims().get_ptr(ctx),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.get_data().ptr_offset_const(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
None
);
ctx.builder.build_call(
memcpy_fn,
&[
ctx.builder.build_bitcast(
ndarray.get_data().get_ptr(ctx),
llvm_pi8,
"",
).into(),
ctx.builder.build_bitcast(
v_data_src_ptr,
llvm_pi8,
"",
).into(),
ctx.builder.build_int_mul(
ndarray_num_elems.into(),
llvm_ndarray_data_t.size_of().unwrap(),
"",
).into(),
llvm_i1.const_zero().into(),
],
"",
);
Ok(Some(v.get_ptr().into()))
}
}
/// See [`CodeGenerator::gen_expr`].
pub fn gen_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v.get_data().get(ctx, generator, index, None).into()
}
}
TypeEnum::TNDArray { .. } => {
return Err(String::from("subscript operator for ndarray not implemented"))
TypeEnum::TNDArray { ty, ndims } => {
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
} else {
return Ok(None)
};
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(
generator,
ctx,
*ty,
*ndims,
v,
&*slice,
)
}
TypeEnum::TTuple { .. } => {
let index: u32 =