From c29cbf6ddd7ad58c0856a0ede9c74a2150f8fefc Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 5 Apr 2022 14:29:20 +0800 Subject: [PATCH] nac3core: add bound check for list slice --- nac3core/src/codegen/expr.rs | 10 +-- nac3core/src/codegen/irrt/mod.rs | 116 ++++++++++++++++++++++++++++-- nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/toplevel/builtins.rs | 26 ++++++- 4 files changed, 140 insertions(+), 14 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 50e8efc..2865ba0 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -238,6 +238,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { pub fn gen_int_ops( &mut self, + generator: &mut dyn CodeGenerator, op: &Operator, lhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>, @@ -273,7 +274,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { (Operator::RShift, _) => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), (Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), (Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").into(), - (Operator::Pow, s) => integer_power(self, lhs, rhs, s).into(), + (Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(), // special implementation? (Operator::MatMult, _) => unreachable!(), } @@ -940,9 +941,9 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right, true) + ctx.gen_int_ops(generator, op, left, right, true) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - ctx.gen_int_ops(op, left, right, false) + ctx.gen_int_ops(generator, op, left, right, false) } else if ty1 == ty2 && ctx.primitives.float == ty1 { ctx.gen_float_ops(op, left, right) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { @@ -1415,6 +1416,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let (start, end, step) = handle_slice_indices(lower, upper, step, ctx, generator, v)?; let length = calculate_len_for_slice_range( + generator, ctx, start, ctx.builder @@ -1436,8 +1438,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let res_ind = handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)?; list_slice_assignment( + generator, ctx, - generator.get_size_type(ctx.ctx), ty, res_array_ret, res_ind, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 5b55f0a..923ae9c 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -6,7 +6,7 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, + types::BasicTypeEnum, values::{IntValue, PointerValue}, AddressSpace, IntPredicate, }; @@ -34,6 +34,7 @@ pub fn load_irrt(ctx: &Context) -> Module { // repeated squaring method adapted from GNU Scientific Library: // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c pub fn integer_power<'ctx, 'a>( + generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, base: IntValue<'ctx>, exp: IntValue<'ctx>, @@ -51,7 +52,21 @@ pub fn integer_power<'ctx, 'a>( let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); ctx.module.add_function(symbol, fn_type, None) }); - // TODO: throw exception when exp < 0 + // throw exception when exp < 0 + let ge_zero = ctx.builder.build_int_compare( + IntPredicate::SGE, + exp, + exp.get_type().const_zero(), + "assert_int_pow_ge_0", + ); + ctx.make_assert( + generator, + ge_zero, + "0:ValueError", + "integer power must be positive or zero", + [None, None, None], + ctx.current_loc, + ); ctx.builder .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") .try_as_basic_value() @@ -60,6 +75,7 @@ pub fn integer_power<'ctx, 'a>( } pub fn calculate_len_for_slice_range<'ctx, 'a>( + generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, start: IntValue<'ctx>, end: IntValue<'ctx>, @@ -72,7 +88,21 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>( ctx.module.add_function(SYMBOL, fn_t, None) }); - // TODO: assert step != 0, throw exception if not + // assert step != 0, throw exception if not + let not_zero = ctx.builder.build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "step must not be zero", + [None, None, None], + ctx.current_loc, + ); ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .try_as_basic_value() @@ -129,7 +159,6 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( generator: &mut G, list: PointerValue<'ctx>, ) -> Result<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), String> { - // TODO: throw exception when step is 0 let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let one = int32.const_int(1, false); @@ -156,6 +185,21 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator)? .into_int_value(); + // assert step != 0, throw exception if not + let not_zero = ctx.builder.build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "slice step cannot be zero", + [None, None, None], + ctx.current_loc, + ); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1"); let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg"); ( @@ -231,14 +275,15 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>( /// Order of tuples assign_idx and value_idx is ('start', 'end', 'step'). /// Negative index should be handled before entering this function pub fn list_slice_assignment<'ctx, 'a>( + generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - size_ty: IntType<'ctx>, ty: BasicTypeEnum<'ctx>, dest_arr: PointerValue<'ctx>, dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_arr: PointerValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { + let size_ty = generator.get_size_type(ctx.ctx); let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::Generic); let int32 = ctx.ctx.i32_type(); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); @@ -282,8 +327,67 @@ pub fn list_slice_assignment<'ctx, 'a>( let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32"); // index in bound and positive should be done - // TODO: assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and + // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // throw exception if not satisfied + let src_end = ctx.builder + .build_select( + ctx.builder.build_int_compare( + inkwell::IntPredicate::SLT, + src_idx.2, + zero, + "is_neg", + ), + ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"), + ctx.builder.build_int_add(src_idx.1, one, "e_add_one"), + "final_e", + ) + .into_int_value(); + let dest_end = ctx.builder + .build_select( + ctx.builder.build_int_compare( + inkwell::IntPredicate::SLT, + dest_idx.2, + zero, + "is_neg", + ), + ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"), + ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"), + "final_e", + ) + .into_int_value(); + let src_slice_len = + calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); + let dest_slice_len = + calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); + let src_eq_dest = ctx.builder.build_int_compare( + IntPredicate::EQ, + src_slice_len, + dest_slice_len, + "slice_src_eq_dest", + ); + let src_slt_dest = ctx.builder.build_int_compare( + IntPredicate::SLT, + src_slice_len, + dest_slice_len, + "slice_src_slt_dest", + ); + let dest_step_eq_one = ctx.builder.build_int_compare( + IntPredicate::EQ, + dest_idx.2, + dest_idx.2.get_type().const_int(1, false), + "slice_dest_step_eq_one", + ); + let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1"); + let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond"); + ctx.make_assert( + generator, + cond, + "0:ValueError", + "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", + [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], + ctx.current_loc, + ); + let new_len = { let args = vec![ dest_idx.0.into(), // dest start idx diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 6fddfb2..ad17b93 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -134,8 +134,8 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( }; let src_ind = handle_slice_indices(&None, &None, &None, ctx, generator, value)?; list_slice_assignment( + generator, ctx, - generator.get_size_type(ctx.ctx), ty, ls, (start, end, step), diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 9491d5e..85c2532 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -670,8 +670,28 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { step = Some(arg.1.clone().to_basic_value_enum(ctx, generator)?); } } - // TODO: error when step == 0 - let step = step.unwrap_or_else(|| int32.const_int(1, false).into()); + let step = match step { + Some(step) => { + let step = step.into_int_value(); + // assert step != 0, throw exception if not + let not_zero = ctx.builder.build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "range() step must not be zero", + [None, None, None], + ctx.current_loc, + ); + step + } + None => int32.const_int(1, false), + }; let stop = stop.unwrap_or_else(|| { let v = start.unwrap(); start = None; @@ -973,7 +993,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(if ctx.unifier.unioned(arg_ty, range_ty) { let arg = arg.into_pointer_value(); let (start, end, step) = destructure_range(ctx, arg); - Some(calculate_len_for_slice_range(ctx, start, end, step).into()) + Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) } else { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero();