forked from M-Labs/nac3
core: Implement codegen for indexing into ndarray
This commit is contained in:
parent
0d5c53e60c
commit
cc538d221a
|
@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{ListValue, RangeValue},
|
classes::{ListValue, NDArrayValue, RangeValue},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
get_llvm_type,
|
get_llvm_type,
|
||||||
|
@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates code for a subscript expression on an `ndarray`.
|
||||||
|
///
|
||||||
|
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||||
|
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||||
|
/// * `v` - The `NDArray` value.
|
||||||
|
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||||
|
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: Type,
|
||||||
|
ndims: Type,
|
||||||
|
v: NDArrayValue<'ctx>,
|
||||||
|
slice: &Expr<Option<Type>>,
|
||||||
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndims = values.iter()
|
||||||
|
.map(|ndim| match *ndim {
|
||||||
|
SymbolValue::U64(v) => Ok(v),
|
||||||
|
SymbolValue::U32(v) => Ok(v as u64),
|
||||||
|
SymbolValue::I32(v) => u64::try_from(v)
|
||||||
|
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||||
|
SymbolValue::I64(v) => u64::try_from(v)
|
||||||
|
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||||
|
_ => unreachable!(),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
|
let ndarray_ty_enum = TypeEnum::TNDArray {
|
||||||
|
ty,
|
||||||
|
ndims: ctx.unifier.get_fresh_literal(
|
||||||
|
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||||
|
|
||||||
|
// Check that len is non-zero
|
||||||
|
let len = v.load_ndims(ctx);
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""),
|
||||||
|
"0:IndexError",
|
||||||
|
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||||
|
[Some(len), None, None],
|
||||||
|
slice.location,
|
||||||
|
);
|
||||||
|
|
||||||
|
if ndims.len() == 1 && ndims[0] == 1 {
|
||||||
|
// Accessing an element from a 1-dimensional `ndarray`
|
||||||
|
|
||||||
|
if let ExprKind::Slice { .. } = &slice.node {
|
||||||
|
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(v.get_data()
|
||||||
|
.get_const(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.ctx.i32_type().const_array(&[index]),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.into()))
|
||||||
|
} else {
|
||||||
|
// Accessing an element from a multi-dimensional `ndarray`
|
||||||
|
|
||||||
|
if let ExprKind::Slice { .. } = &slice.node {
|
||||||
|
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
|
// elements over
|
||||||
|
let subscripted_ndarray = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_ndarray_t.into(),
|
||||||
|
None
|
||||||
|
)?;
|
||||||
|
let ndarray = NDArrayValue::from_ptr_val(
|
||||||
|
subscripted_ndarray,
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
|
||||||
|
let num_dims = v.load_ndims(ctx);
|
||||||
|
ndarray.store_ndims(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
|
let memcpy_fn_name = format!(
|
||||||
|
"llvm.memcpy.p0i8.p0i8.i{}",
|
||||||
|
generator.get_size_type(ctx.ctx).get_bit_width(),
|
||||||
|
);
|
||||||
|
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pi8.into(),
|
||||||
|
llvm_pi8.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_i1.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let v_dims_src_ptr = v.get_dims().ptr_offset(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
ctx.builder.build_call(
|
||||||
|
memcpy_fn,
|
||||||
|
&[
|
||||||
|
ctx.builder.build_bitcast(
|
||||||
|
ndarray.get_dims().get_ptr(ctx),
|
||||||
|
llvm_pi8,
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
ctx.builder.build_bitcast(
|
||||||
|
v_dims_src_ptr,
|
||||||
|
llvm_pi8,
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
ctx.builder.build_int_mul(
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
llvm_usize.size_of(),
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
llvm_i1.const_zero().into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray.load_ndims(ctx),
|
||||||
|
ndarray.get_dims().get_ptr(ctx),
|
||||||
|
);
|
||||||
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
let v_data_src_ptr = v.get_data().ptr_offset_const(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.ctx.i32_type().const_array(&[index]),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
ctx.builder.build_call(
|
||||||
|
memcpy_fn,
|
||||||
|
&[
|
||||||
|
ctx.builder.build_bitcast(
|
||||||
|
ndarray.get_data().get_ptr(ctx),
|
||||||
|
llvm_pi8,
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
ctx.builder.build_bitcast(
|
||||||
|
v_data_src_ptr,
|
||||||
|
llvm_pi8,
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
ctx.builder.build_int_mul(
|
||||||
|
ndarray_num_elems.into(),
|
||||||
|
llvm_ndarray_data_t.size_of().unwrap(),
|
||||||
|
"",
|
||||||
|
).into(),
|
||||||
|
llvm_i1.const_zero().into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Some(v.get_ptr().into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_expr`].
|
/// See [`CodeGenerator::gen_expr`].
|
||||||
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
v.get_data().get(ctx, generator, index, None).into()
|
v.get_data().get(ctx, generator, index, None).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { .. } => {
|
TypeEnum::TNDArray { ty, ndims } => {
|
||||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
let v = NDArrayValue::from_ptr_val(v, usize, None);
|
||||||
|
|
||||||
|
return gen_ndarray_subscript_expr(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
*ty,
|
||||||
|
*ndims,
|
||||||
|
v,
|
||||||
|
&*slice,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
|
|
Loading…
Reference in New Issue