forked from M-Labs/nac3
nac3core: add bound check for list slice
This commit is contained in:
parent
7443c5ea0f
commit
c29cbf6ddd
@ -238,6 +238,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
|
||||
pub fn gen_int_ops(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
op: &Operator,
|
||||
lhs: BasicValueEnum<'ctx>,
|
||||
rhs: BasicValueEnum<'ctx>,
|
||||
@ -273,7 +274,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
(Operator::RShift, _) => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
|
||||
(Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
|
||||
(Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").into(),
|
||||
(Operator::Pow, s) => integer_power(self, lhs, rhs, s).into(),
|
||||
(Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(),
|
||||
// special implementation?
|
||||
(Operator::MatMult, _) => unreachable!(),
|
||||
}
|
||||
@ -940,9 +941,9 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>(
|
||||
// which would be unchanged until further unification, which we would never do
|
||||
// when doing code generation for function instances
|
||||
Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
||||
ctx.gen_int_ops(op, left, right, true)
|
||||
ctx.gen_int_ops(generator, op, left, right, true)
|
||||
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
||||
ctx.gen_int_ops(op, left, right, false)
|
||||
ctx.gen_int_ops(generator, op, left, right, false)
|
||||
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
||||
ctx.gen_float_ops(op, left, right)
|
||||
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
||||
@ -1415,6 +1416,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
|
||||
let (start, end, step) =
|
||||
handle_slice_indices(lower, upper, step, ctx, generator, v)?;
|
||||
let length = calculate_len_for_slice_range(
|
||||
generator,
|
||||
ctx,
|
||||
start,
|
||||
ctx.builder
|
||||
@ -1436,8 +1438,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
|
||||
let res_ind =
|
||||
handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)?;
|
||||
list_slice_assignment(
|
||||
generator,
|
||||
ctx,
|
||||
generator.get_size_type(ctx.ctx),
|
||||
ty,
|
||||
res_array_ret,
|
||||
res_ind,
|
||||
|
@ -6,7 +6,7 @@ use inkwell::{
|
||||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
types::BasicTypeEnum,
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
@ -34,6 +34,7 @@ pub fn load_irrt(ctx: &Context) -> Module {
|
||||
// repeated squaring method adapted from GNU Scientific Library:
|
||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
pub fn integer_power<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
base: IntValue<'ctx>,
|
||||
exp: IntValue<'ctx>,
|
||||
@ -51,7 +52,21 @@ pub fn integer_power<'ctx, 'a>(
|
||||
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
||||
ctx.module.add_function(symbol, fn_type, None)
|
||||
});
|
||||
// TODO: throw exception when exp < 0
|
||||
// throw exception when exp < 0
|
||||
let ge_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
exp,
|
||||
exp.get_type().const_zero(),
|
||||
"assert_int_pow_ge_0",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ge_zero,
|
||||
"0:ValueError",
|
||||
"integer power must be positive or zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||
.try_as_basic_value()
|
||||
@ -60,6 +75,7 @@ pub fn integer_power<'ctx, 'a>(
|
||||
}
|
||||
|
||||
pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
start: IntValue<'ctx>,
|
||||
end: IntValue<'ctx>,
|
||||
@ -72,7 +88,21 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
// TODO: assert step != 0, throw exception if not
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"step must not be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||
.try_as_basic_value()
|
||||
@ -129,7 +159,6 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
list: PointerValue<'ctx>,
|
||||
) -> Result<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), String> {
|
||||
// TODO: throw exception when step is 0
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let one = int32.const_int(1, false);
|
||||
@ -156,6 +185,21 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>(
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator)?
|
||||
.into_int_value();
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"slice step cannot be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1");
|
||||
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg");
|
||||
(
|
||||
@ -231,14 +275,15 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>(
|
||||
/// Order of tuples assign_idx and value_idx is ('start', 'end', 'step').
|
||||
/// Negative index should be handled before entering this function
|
||||
pub fn list_slice_assignment<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
size_ty: IntType<'ctx>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
dest_arr: PointerValue<'ctx>,
|
||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
src_arr: PointerValue<'ctx>,
|
||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
) {
|
||||
let size_ty = generator.get_size_type(ctx.ctx);
|
||||
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::Generic);
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
||||
@ -282,8 +327,67 @@ pub fn list_slice_assignment<'ctx, 'a>(
|
||||
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
|
||||
|
||||
// index in bound and positive should be done
|
||||
// TODO: assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||
// throw exception if not satisfied
|
||||
let src_end = ctx.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SLT,
|
||||
src_idx.2,
|
||||
zero,
|
||||
"is_neg",
|
||||
),
|
||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"),
|
||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one"),
|
||||
"final_e",
|
||||
)
|
||||
.into_int_value();
|
||||
let dest_end = ctx.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SLT,
|
||||
dest_idx.2,
|
||||
zero,
|
||||
"is_neg",
|
||||
),
|
||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"),
|
||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"),
|
||||
"final_e",
|
||||
)
|
||||
.into_int_value();
|
||||
let src_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||
let dest_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||
let src_eq_dest = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
src_slice_len,
|
||||
dest_slice_len,
|
||||
"slice_src_eq_dest",
|
||||
);
|
||||
let src_slt_dest = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
src_slice_len,
|
||||
dest_slice_len,
|
||||
"slice_src_slt_dest",
|
||||
);
|
||||
let dest_step_eq_one = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
dest_idx.2,
|
||||
dest_idx.2.get_type().const_int(1, false),
|
||||
"slice_dest_step_eq_one",
|
||||
);
|
||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1");
|
||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond");
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cond,
|
||||
"0:ValueError",
|
||||
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
||||
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let new_len = {
|
||||
let args = vec![
|
||||
dest_idx.0.into(), // dest start idx
|
||||
|
@ -134,8 +134,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
|
||||
};
|
||||
let src_ind = handle_slice_indices(&None, &None, &None, ctx, generator, value)?;
|
||||
list_slice_assignment(
|
||||
generator,
|
||||
ctx,
|
||||
generator.get_size_type(ctx.ctx),
|
||||
ty,
|
||||
ls,
|
||||
(start, end, step),
|
||||
|
@ -670,8 +670,28 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||
step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)?);
|
||||
}
|
||||
}
|
||||
// TODO: error when step == 0
|
||||
let step = step.unwrap_or_else(|| int32.const_int(1, false).into());
|
||||
let step = match step {
|
||||
Some(step) => {
|
||||
let step = step.into_int_value();
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"range() step must not be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
step
|
||||
}
|
||||
None => int32.const_int(1, false),
|
||||
};
|
||||
let stop = stop.unwrap_or_else(|| {
|
||||
let v = start.unwrap();
|
||||
start = None;
|
||||
@ -973,7 +993,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
|
||||
let arg = arg.into_pointer_value();
|
||||
let (start, end, step) = destructure_range(ctx, arg);
|
||||
Some(calculate_len_for_slice_range(ctx, start, end, step).into())
|
||||
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
|
||||
} else {
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
|
Loading…
Reference in New Issue
Block a user