forked from M-Labs/nac3
core: Implement codegen for indexing into ndarray
This commit is contained in:
parent
0d5c53e60c
commit
cc538d221a
|
@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{ListValue, RangeValue},
|
||||
classes::{ListValue, NDArrayValue, RangeValue},
|
||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||
gen_in_range_check,
|
||||
get_llvm_type,
|
||||
|
@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `v` - The `NDArray` value.
|
||||
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
ndims: Type,
|
||||
v: NDArrayValue<'ctx>,
|
||||
slice: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let ndims = values.iter()
|
||||
.map(|ndim| match *ndim {
|
||||
SymbolValue::U64(v) => Ok(v),
|
||||
SymbolValue::U32(v) => Ok(v as u64),
|
||||
SymbolValue::I32(v) => u64::try_from(v)
|
||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||
SymbolValue::I64(v) => u64::try_from(v)
|
||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
assert!(!ndims.is_empty());
|
||||
|
||||
let ndarray_ty_enum = TypeEnum::TNDArray {
|
||||
ty,
|
||||
ndims: ctx.unifier.get_fresh_literal(
|
||||
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||
None,
|
||||
),
|
||||
};
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||
|
||||
// Check that len is non-zero
|
||||
let len = v.load_ndims(ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||
[Some(len), None, None],
|
||||
slice.location,
|
||||
);
|
||||
|
||||
if ndims.len() == 1 && ndims[0] == 1 {
|
||||
// Accessing an element from a 1-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
Ok(Some(v.get_data()
|
||||
.get_const(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.ctx.i32_type().const_array(&[index]),
|
||||
None,
|
||||
)
|
||||
.into()))
|
||||
} else {
|
||||
// Accessing an element from a multi-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||
// elements over
|
||||
let subscripted_ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(
|
||||
subscripted_ndarray,
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
|
||||
let num_dims = v.load_ndims(ctx);
|
||||
ndarray.store_ndims(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""),
|
||||
);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
let memcpy_fn_name = format!(
|
||||
"llvm.memcpy.p0i8.p0i8.i{}",
|
||||
generator.get_size_type(ctx.ctx).get_bit_width(),
|
||||
);
|
||||
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_pi8.into(),
|
||||
llvm_pi8.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_i1.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let v_dims_src_ptr = v.get_dims().ptr_offset(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_int(1, false),
|
||||
None,
|
||||
);
|
||||
ctx.builder.build_call(
|
||||
memcpy_fn,
|
||||
&[
|
||||
ctx.builder.build_bitcast(
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_bitcast(
|
||||
v_dims_src_ptr,
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_int_mul(
|
||||
ndarray_num_dims.into(),
|
||||
llvm_usize.size_of(),
|
||||
"",
|
||||
).into(),
|
||||
llvm_i1.const_zero().into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
let v_data_src_ptr = v.get_data().ptr_offset_const(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.ctx.i32_type().const_array(&[index]),
|
||||
None
|
||||
);
|
||||
ctx.builder.build_call(
|
||||
memcpy_fn,
|
||||
&[
|
||||
ctx.builder.build_bitcast(
|
||||
ndarray.get_data().get_ptr(ctx),
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_bitcast(
|
||||
v_data_src_ptr,
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_int_mul(
|
||||
ndarray_num_elems.into(),
|
||||
llvm_ndarray_data_t.size_of().unwrap(),
|
||||
"",
|
||||
).into(),
|
||||
llvm_i1.const_zero().into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
Ok(Some(v.get_ptr().into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
v.get_data().get(ctx, generator, index, None).into()
|
||||
}
|
||||
}
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
TypeEnum::TNDArray { ty, ndims } => {
|
||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let v = NDArrayValue::from_ptr_val(v, usize, None);
|
||||
|
||||
return gen_ndarray_subscript_expr(
|
||||
generator,
|
||||
ctx,
|
||||
*ty,
|
||||
*ndims,
|
||||
v,
|
||||
&*slice,
|
||||
)
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
|
|
Loading…
Reference in New Issue